diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f18ce75863557f81cafb3154a0b0bdccfab3be1b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +github_page/control.pdf filter=lfs diff=lfs merge=lfs -text +github_page/p10.png filter=lfs diff=lfs merge=lfs -text +github_page/p11.png filter=lfs diff=lfs merge=lfs -text +github_page/p12.png filter=lfs diff=lfs merge=lfs -text +github_page/p13.png filter=lfs diff=lfs merge=lfs -text +github_page/p14.png filter=lfs diff=lfs merge=lfs -text +github_page/p15.png filter=lfs diff=lfs merge=lfs -text +github_page/p16b.png filter=lfs diff=lfs merge=lfs -text +github_page/p17.png filter=lfs diff=lfs merge=lfs -text +github_page/p18.png filter=lfs diff=lfs merge=lfs -text +github_page/p19.png filter=lfs diff=lfs merge=lfs -text +github_page/p2.png filter=lfs diff=lfs merge=lfs -text +github_page/p20.png filter=lfs diff=lfs merge=lfs -text +github_page/p21.png filter=lfs diff=lfs merge=lfs -text +github_page/p3.png filter=lfs diff=lfs merge=lfs -text +github_page/p4.png filter=lfs diff=lfs merge=lfs -text +github_page/p5.png filter=lfs diff=lfs merge=lfs -text +github_page/p6.png filter=lfs diff=lfs merge=lfs -text +github_page/p7.png filter=lfs diff=lfs merge=lfs -text +github_page/p8.png filter=lfs diff=lfs merge=lfs -text +github_page/p9.png filter=lfs diff=lfs merge=lfs -text +github_page/t/op.png filter=lfs diff=lfs merge=lfs -text +github_page/uc2a.png filter=lfs diff=lfs merge=lfs -text +github_page/uc2b.png filter=lfs diff=lfs merge=lfs -text +github_page/uc3.png filter=lfs diff=lfs merge=lfs -text +github_page/uc4.png filter=lfs diff=lfs merge=lfs -text +github_page/uc6.png filter=lfs diff=lfs merge=lfs -text +github_page/uci1.png filter=lfs diff=lfs merge=lfs -text +github_page/uci2.png filter=lfs diff=lfs merge=lfs -text +github_page/uci3.png filter=lfs diff=lfs merge=lfs -text +github_page/uci4.png filter=lfs diff=lfs merge=lfs -text +output-images/rgb2.png filter=lfs diff=lfs merge=lfs -text +test_imgs/bird.png filter=lfs diff=lfs merge=lfs -text +test_imgs/building.png filter=lfs diff=lfs merge=lfs -text +test_imgs/building2.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..325e5fd26480f88769892b2c263e91daca4ba719 --- /dev/null +++ b/.gitignore @@ -0,0 +1,143 @@ +.idea/ + +training/ +lightning_logs/ +image_log/ + +*.pth +*.pt +*.ckpt +*.safetensors + +gradio_pose2image_private.py +gradio_canny2image_private.py + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 4a970c5f4a764a53c10ea9c06adf1604108a7df0..2bab0977eb779ba65c1ab85cb52a8d190e831f20 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,354 @@ --- title: HumanSD -emoji: 👀 -colorFrom: red -colorTo: yellow +app_file: gradio_humanpose2image.py sdk: gradio -sdk_version: 3.45.1 -app_file: app.py -pinned: false +sdk_version: 3.44.3 --- +# News: A nightly version of ControlNet 1.1 is released! -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +[ControlNet 1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly) is released. Those new models will be merged to this repo after we make sure that everything is good. + +# Below is ControlNet 1.0 + +Official implementation of [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543). + +ControlNet is a neural network structure to control diffusion models by adding extra conditions. + +![img](github_page/he.png) + +It copys the weights of neural network blocks into a "locked" copy and a "trainable" copy. + +The "trainable" one learns your condition. The "locked" one preserves your model. + +Thanks to this, training with small dataset of image pairs will not destroy the production-ready diffusion models. + +The "zero convolution" is 1×1 convolution with both weight and bias initialized as zeros. + +Before training, all zero convolutions output zeros, and ControlNet will not cause any distortion. + +No layer is trained from scratch. You are still fine-tuning. Your original model is safe. + +This allows training on small-scale or even personal devices. + +This is also friendly to merge/replacement/offsetting of models/weights/blocks/layers. + +### FAQ + +**Q:** But wait, if the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why "zero convolution" works? + +**A:** This is not true. [See an explanation here](docs/faq.md). + +# Stable Diffusion + ControlNet + +By repeating the above simple structure 14 times, we can control stable diffusion in this way: + +![img](github_page/sd.png) + +In this way, the ControlNet can **reuse** the SD encoder as a **deep, strong, robust, and powerful backbone** to learn diverse controls. Many evidences (like [this](https://jerryxu.net/ODISE/) and [this](https://vpd.ivg-research.xyz/)) validate that the SD encoder is an excellent backbone. + +Note that the way we connect layers is computational efficient. The original SD encoder does not need to store gradients (the locked original SD Encoder Block 1234 and Middle). The required GPU memory is not much larger than original SD, although many layers are added. Great! + +# Features & News + +2023/0/14 - We released [ControlNet 1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly). Those new models will be merged to this repo after we make sure that everything is good. + +2023/03/03 - We released a discussion - [Precomputed ControlNet: Speed up ControlNet by 45%, but is it necessary?](https://github.com/lllyasviel/ControlNet/discussions/216) + +2023/02/26 - We released a blog - [Ablation Study: Why ControlNets use deep encoder? What if it was lighter? Or even an MLP?](https://github.com/lllyasviel/ControlNet/discussions/188) + +2023/02/20 - Implementation for non-prompt mode released. See also [Guess Mode / Non-Prompt Mode](#guess-anchor). + +2023/02/12 - Now you can play with any community model by [Transferring the ControlNet](https://github.com/lllyasviel/ControlNet/discussions/12). + +2023/02/11 - [Low VRAM mode](docs/low_vram.md) is added. Please use this mode if you are using 8GB GPU(s) or if you want larger batch size. + +# Production-Ready Pretrained Models + +First create a new conda environment + + conda env create -f environment.yaml + conda activate control + +All models and detectors can be downloaded from [our Hugging Face page](https://huggingface.co/lllyasviel/ControlNet). Make sure that SD models are put in "ControlNet/models" and detectors are put in "ControlNet/annotator/ckpts". Make sure that you download all necessary pretrained weights and detector models from that Hugging Face page, including HED edge detection model, Midas depth estimation model, Openpose, and so on. + +We provide 9 Gradio apps with these models. + +All test images can be found at the folder "test_imgs". + +## ControlNet with Canny Edge + +Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection) + + python gradio_canny2image.py + +The Gradio app also allows you to change the Canny edge thresholds. Just try it for more details. + +Prompt: "bird" +![p](github_page/p1.png) + +Prompt: "cute dog" +![p](github_page/p2.png) + +## ControlNet with M-LSD Lines + +Stable Diffusion 1.5 + ControlNet (using simple M-LSD straight line detection) + + python gradio_hough2image.py + +The Gradio app also allows you to change the M-LSD thresholds. Just try it for more details. + +Prompt: "room" +![p](github_page/p3.png) + +Prompt: "building" +![p](github_page/p4.png) + +## ControlNet with HED Boundary + +Stable Diffusion 1.5 + ControlNet (using soft HED Boundary) + + python gradio_hed2image.py + +The soft HED Boundary will preserve many details in input images, making this app suitable for recoloring and stylizing. Just try it for more details. + +Prompt: "oil painting of handsome old man, masterpiece" +![p](github_page/p5.png) + +Prompt: "Cyberpunk robot" +![p](github_page/p6.png) + +## ControlNet with User Scribbles + +Stable Diffusion 1.5 + ControlNet (using Scribbles) + + python gradio_scribble2image.py + +Note that the UI is based on Gradio, and Gradio is somewhat difficult to customize. Right now you need to draw scribbles outside the UI (using your favorite drawing software, for example, MS Paint) and then import the scribble image to Gradio. + +Prompt: "turtle" +![p](github_page/p7.png) + +Prompt: "hot air balloon" +![p](github_page/p8.png) + +### Interactive Interface + +We actually provide an interactive interface + + python gradio_scribble2image_interactive.py + +~~However, because gradio is very [buggy](https://github.com/gradio-app/gradio/issues/3166) and difficult to customize, right now, user need to first set canvas width and heights and then click "Open drawing canvas" to get a drawing area. Please do not upload image to that drawing canvas. Also, the drawing area is very small; it should be bigger. But I failed to find out how to make it larger. Again, gradio is really buggy.~~ (Now fixed, will update asap) + +The below dog sketch is drawn by me. Perhaps we should draw a better dog for showcase. + +Prompt: "dog in a room" +![p](github_page/p20.png) + +## ControlNet with Fake Scribbles + +Stable Diffusion 1.5 + ControlNet (using fake scribbles) + + python gradio_fake_scribble2image.py + +Sometimes we are lazy, and we do not want to draw scribbles. This script use the exactly same scribble-based model but use a simple algorithm to synthesize scribbles from input images. + +Prompt: "bag" +![p](github_page/p9.png) + +Prompt: "shose" (Note that "shose" is a typo; it should be "shoes". But it still seems to work.) +![p](github_page/p10.png) + +## ControlNet with Human Pose + +Stable Diffusion 1.5 + ControlNet (using human pose) + + python gradio_pose2image.py + +Apparently, this model deserves a better UI to directly manipulate pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then the Openpose will detect the pose for you. + +Prompt: "Chief in the kitchen" +![p](github_page/p11.png) + +Prompt: "An astronaut on the moon" +![p](github_page/p12.png) + +## ControlNet with Semantic Segmentation + +Stable Diffusion 1.5 + ControlNet (using semantic segmentation) + + python gradio_seg2image.py + +This model use ADE20K's segmentation protocol. Again, this model deserves a better UI to directly draw the segmentations. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then a model called Uniformer will detect the segmentations for you. Just try it for more details. + +Prompt: "House" +![p](github_page/p13.png) + +Prompt: "River" +![p](github_page/p14.png) + +## ControlNet with Depth + +Stable Diffusion 1.5 + ControlNet (using depth map) + + python gradio_depth2image.py + +Great! Now SD 1.5 also have a depth control. FINALLY. So many possibilities (considering SD1.5 has much more community models than SD2). + +Note that different from Stability's model, the ControlNet receive the full 512×512 depth map, rather than 64×64 depth. Note that Stability's SD2 depth model use 64*64 depth maps. This means that the ControlNet will preserve more details in the depth map. + +This is always a strength because if users do not want to preserve more details, they can simply use another SD to post-process an i2i. But if they want to preserve more details, ControlNet becomes their only choice. Again, SD2 uses 64×64 depth, we use 512×512. + +Prompt: "Stormtrooper's lecture" +![p](github_page/p15.png) + +## ControlNet with Normal Map + +Stable Diffusion 1.5 + ControlNet (using normal map) + + python gradio_normal2image.py + +This model use normal map. Rightnow in the APP, the normal is computed from the midas depth map and a user threshold (to determine how many area is background with identity normal face to viewer, tune the "Normal background threshold" in the gradio app to get a feeling). + +Prompt: "Cute toy" +![p](github_page/p17.png) + +Prompt: "Plaster statue of Abraham Lincoln" +![p](github_page/p18.png) + +Compared to depth model, this model seems to be a bit better at preserving the geometry. This is intuitive: minor details are not salient in depth maps, but are salient in normal maps. Below is the depth result with same inputs. You can see that the hairstyle of the man in the input image is modified by depth model, but preserved by the normal model. + +Prompt: "Plaster statue of Abraham Lincoln" +![p](github_page/p19.png) + +## ControlNet with Anime Line Drawing + +We also trained a relatively simple ControlNet for anime line drawings. This tool may be useful for artistic creations. (Although the image details in the results is a bit modified, since it still diffuse latent images.) + +This model is not available right now. We need to evaluate the potential risks before releasing this model. Nevertheless, you may be interested in [transferring the ControlNet to any community model](https://github.com/lllyasviel/ControlNet/discussions/12). + +![p](github_page/p21.png) + + + +# Guess Mode / Non-Prompt Mode + +The "guess mode" (or called non-prompt mode) will completely unleash all the power of the very powerful ControlNet encoder. + +See also the blog - [Ablation Study: Why ControlNets use deep encoder? What if it was lighter? Or even an MLP?](https://github.com/lllyasviel/ControlNet/discussions/188) + +You need to manually check the "Guess Mode" toggle to enable this mode. + +In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts. + +**Let's have fun with some very challenging experimental settings!** + +**No prompts. No "positive" prompts. No "negative" prompts. No extra caption detector. One single diffusion loop.** + +For this mode, we recommend to use 50 steps and guidance scale between 3 and 5. + +![p](github_page/uc2a.png) + +No prompts: + +![p](github_page/uc2b.png) + +Note that the below example is 768×768. No prompts. No "positive" prompts. No "negative" prompts. + +![p](github_page/uc1.png) + +By tuning the parameters, you can get some very intereting results like below: + +![p](github_page/uc3.png) + +Because no prompt is available, the ControlNet encoder will "guess" what is in the control map. Sometimes the guess result is really interesting. Because diffusion algorithm can essentially give multiple results, the ControlNet seems able to give multiple guesses, like this: + +![p](github_page/uc4.png) + +Without prompt, the HED seems good at generating images look like paintings when the control strength is relatively low: + +![p](github_page/uc6.png) + +The Guess Mode is also supported in [WebUI Plugin](https://github.com/Mikubill/sd-webui-controlnet): + +![p](github_page/uci1.png) + +No prompts. Default WebUI parameters. Pure random results with the seed being 12345. Standard SD1.5. Input scribble is in "test_imgs" folder to reproduce. + +![p](github_page/uci2.png) + +Below is another challenging example: + +![p](github_page/uci3.png) + +No prompts. Default WebUI parameters. Pure random results with the seed being 12345. Standard SD1.5. Input scribble is in "test_imgs" folder to reproduce. + +![p](github_page/uci4.png) + +Note that in the guess mode, you will still be able to input prompts. The only difference is that the model will "try harder" to guess what is in the control map even if you do not provide the prompt. Just try it yourself! + +Besides, if you write some scripts (like BLIP) to generate image captions from the "guess mode" images, and then use the generated captions as prompts to diffuse again, you will get a SOTA pipeline for fully automatic conditional image generating. + +# Combining Multiple ControlNets + +ControlNets are composable: more than one ControlNet can be easily composed to multi-condition control. + +Right now this feature is in experimental stage in the [Mikubill' A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet): + +![p](github_page/multi2.png) + +![p](github_page/multi.png) + +As long as the models are controlling the same SD, the "boundary" between different research projects does not even exist. This plugin also allows different methods to work together! + +# Use ControlNet in Any Community Model (SD1.X) + +This is an experimental feature. + +[See the steps here](https://github.com/lllyasviel/ControlNet/discussions/12). + +Or you may want to use the [Mikubill' A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet) which is plug-and-play and does not need manual merging. + +# Annotate Your Own Data + +We provide simple python scripts to process images. + +[See a gradio example here](docs/annotator.md). + +# Train with Your Own Data + +Training a ControlNet is as easy as (or even easier than) training a simple pix2pix. + +[See the steps here](docs/train.md). + +# Related Resources + +Special Thank to the great project - [Mikubill' A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet) ! + +We also thank Hysts for making [Hugging Face Space](https://huggingface.co/spaces/hysts/ControlNet) as well as more than 65 models in that amazing [Colab list](https://github.com/camenduru/controlnet-colab)! + +Thank haofanwang for making [ControlNet-for-Diffusers](https://github.com/haofanwang/ControlNet-for-Diffusers)! + +We also thank all authors for making Controlnet DEMOs, including but not limited to [fffiloni](https://huggingface.co/spaces/fffiloni/ControlNet-Video), [other-model](https://huggingface.co/spaces/hysts/ControlNet-with-other-models), [ThereforeGames](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/7784), [RamAnanth1](https://huggingface.co/spaces/RamAnanth1/ControlNet), etc! + +Besides, you may also want to read these amazing related works: + +[Composer: Creative and Controllable Image Synthesis with Composable Conditions](https://github.com/damo-vilab/composer): A much bigger model to control diffusion! + +[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://github.com/TencentARC/T2I-Adapter): A much smaller model to control stable diffusion! + +[ControlLoRA: A Light Neural Network To Control Stable Diffusion Spatial Information](https://github.com/HighCWu/ControlLoRA): Implement Controlnet using LORA! + +And these amazing recent projects: [InstructPix2Pix Learning to Follow Image Editing Instructions](https://www.timothybrooks.com/instruct-pix2pix), [Pix2pix-zero: Zero-shot Image-to-Image Translation](https://github.com/pix2pixzero/pix2pix-zero), [Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation](https://github.com/MichalGeyer/plug-and-play), [MaskSketch: Unpaired Structure-guided Masked Image Generation](https://arxiv.org/abs/2302.05496), [SEGA: Instructing Diffusion using Semantic Dimensions](https://arxiv.org/abs/2301.12247), [Universal Guidance for Diffusion Models](https://github.com/arpitbansal297/Universal-Guided-Diffusion), [Region-Aware Diffusion for Zero-shot Text-driven Image Editing](https://github.com/haha-lisa/RDM-Region-Aware-Diffusion-Model), [Domain Expansion of Image Generators](https://arxiv.org/abs/2301.05225), [Image Mixer](https://twitter.com/LambdaAPI/status/1626327289288957956), [MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation](https://multidiffusion.github.io/) + +# Citation + + @misc{zhang2023adding, + title={Adding Conditional Control to Text-to-Image Diffusion Models}, + author={Lvmin Zhang and Anyi Rao and Maneesh Agrawala}, + booktitle={IEEE International Conference on Computer Vision (ICCV)} + year={2023}, + } + +[Arxiv Link](https://arxiv.org/abs/2302.05543) + +[Supplementary Materials](https://lllyasviel.github.io/misc/202309/cnet_supp.pdf) diff --git a/__pycache__/aagenerator.cpython-38.pyc b/__pycache__/aagenerator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9849d946ed6ffc8e681519d81f42162c2d1755b9 Binary files /dev/null and b/__pycache__/aagenerator.cpython-38.pyc differ diff --git a/__pycache__/config.cpython-38.pyc b/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4b3684bf5087f62eda612c93506e51ccbd85195 Binary files /dev/null and b/__pycache__/config.cpython-38.pyc differ diff --git a/__pycache__/share.cpython-38.pyc b/__pycache__/share.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a7eb816a3213506c903db1c7013205ab17a2724 Binary files /dev/null and b/__pycache__/share.cpython-38.pyc differ diff --git a/aa-pose-inference.py b/aa-pose-inference.py new file mode 100644 index 0000000000000000000000000000000000000000..bc13efe2228a0bee05289ab004173ab6a4433f94 --- /dev/null +++ b/aa-pose-inference.py @@ -0,0 +1,993 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers import AutoTokenizer, PretrainedConfig +from transformers.utils import ContextManagers +from PIL import Image +import PIL +from PIL import ImageFile + +import diffusers +from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from openclip.training.data import get_wds_dataset, get_wds_dataset_cond +from diffusers.schedulers import DDIMScheduler, DDPMScheduler, \ + DEISMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, \ + PNDMScheduler, LMSDiscreteScheduler, UniPCMultistepScheduler + +from models.embedder import Embedder +from pipelines.pipeline_stable_diffusion_mb_downup import StableDiffusionPipeline +from collections import OrderedDict +import boto3 +from diffusers.models.controlnet_composer import ControlNetModel +# from pipelines.pipeline_controlnet_composer import StableDiffusionControlNetPipeline +from pipelines.pipeline_controlnet_composer_sdxl import StableDiffusionXLControlNetPipeline +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info +import json +import cv2 +import seaborn as sns + +if is_wandb_available(): + import wandb + +logger = get_logger(__name__, log_level="INFO") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.0.dev0") + + +def draw_humansd_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None, humansd_skeleton_width=10): + humansd_skeleton = [ + [0, 0, 1], + [1, 0, 2], + [2, 1, 3], + [3, 2, 4], + [4, 3, 5], + [5, 4, 6], + [6, 5, 7], + [7, 6, 8], + [8, 7, 9], + [9, 8, 10], + [10, 5, 11], + [11, 6, 12], + [12, 11, 13], + [13, 12, 14], + [14, 13, 15], + [15, 14, 16], + ] + # humansd_skeleton_width=10 + humansd_color = sns.color_palette("hls", len(humansd_skeleton)) + + def plot_kpts(img_draw, kpts, color, edgs, width): + for idx, kpta, kptb in edgs: + if kpts[kpta, 2] > mmpose_detection_thresh and \ + kpts[kptb, 2] > mmpose_detection_thresh: + line_color = tuple([int(255 * color_i) for color_i in color[idx]]) + + cv2.line(img_draw, (int(kpts[kpta, 0]), int(kpts[kpta, 1])), (int(kpts[kptb, 0]), int(kpts[kptb, 1])), + line_color, width) + cv2.circle(img_draw, (int(kpts[kpta, 0]), int(kpts[kpta, 1])), width // 2, line_color, -1) + cv2.circle(img_draw, (int(kpts[kptb, 0]), int(kpts[kptb, 1])), width // 2, line_color, -1) + + if image is None: + pose_image = np.zeros((height, width, 3), dtype=np.uint8) + else: + pose_image = np.array(image, dtype=np.uint8) + for person_i in range(len(pose)): + if np.sum(pose[person_i]) > 0: + plot_kpts(pose_image, pose[person_i], humansd_color, humansd_skeleton, humansd_skeleton_width) + + return pose_image + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + ################################### newly added args ################################### + parser.add_argument("--ref_path", type=str, default="/data_laion/alvin/Dataset/evaluation/debug/42361.png") + parser.add_argument("--prompt", type=str, default="A person riding skis down a snow covered slope.") + parser.add_argument("--t2mn_path", type=str, + default="/data_laion/alvin/sd4human/a-ranstart-body-sdv20-v-nd-flaw-avg-copy1-glc-resume288k-512-ft1024/checkpoint-388000") + parser.add_argument("--controlnet_model_name_or_path", type=str, + default="/data_laion/alvin/sd4human/ctrl-sdxl10-eps-glc-composer-bmn-sum-1024/checkpoint-91000") + parser.add_argument('--step_num1', default=50, type=int) + parser.add_argument('--step_num2', default=50, type=int) + parser.add_argument('--size', default=2048, type=int) + parser.add_argument("--pretrained_vae_model_name_or_path", type=str, + default='/fsx_laion/alvin/pretrain/sdxl-vae-fp16-fix') + parser.add_argument('--normalize_dist', default=True, action="store_false") + parser.add_argument('--change_whole_to_body', default=True, action="store_false") + parser.add_argument('--off_wa', default=True, action="store_false") + parser.add_argument('--flaw', default=True, action="store_false") + parser.add_argument("--enable_xformers_memory_efficient_attention", default=True, action="store_false", + help="Whether or not to use xformers.") + # statistics for three datasets, laion+coyo+getty + parser.add_argument("--rgb_mean", type=float, default=0.14654) + parser.add_argument("--rgb_std", type=float, default=1.03744) + # parser.add_argument("--whole_mean", type=float, default=0.14713) + # parser.add_argument("--whole_std", type=float, default=0.96812) + parser.add_argument("--whole_mean", type=float, default=-0.2599426086956522) + parser.add_argument("--whole_std", type=float, default=1.3836632689065582) + parser.add_argument("--body_mean", type=float, default=-0.2481) + parser.add_argument("--body_std", type=float, default=1.45647) + parser.add_argument("--depth_mean", type=float, default=0.21360) + parser.add_argument("--depth_std", type=float, default=1.20629) + parser.add_argument("--normal_mean", type=float, default=0.60303) + parser.add_argument("--normal_std", type=float, default=0.91429) + + # # statistics for two datasetsm laion+coyo + # parser.add_argument("--rgb_mean", type=float, default=0.144028) + # parser.add_argument("--rgb_std", type=float, default=1.0420677550094796) + # parser.add_argument("--whole_mean", type=float, default=-0.2598586666666667) + # parser.add_argument("--whole_std", type=float, default=1.3824869261991977) + # parser.add_argument("--body_mean", type=float, default=-0.2481) + # parser.add_argument("--body_std", type=float, default=1.45647) + # parser.add_argument("--depth_mean", type=float, default=0.22104533333333334) + # parser.add_argument("--depth_std", type=float, default=1.2044201368629092) + # parser.add_argument("--normal_mean", type=float, default=0.6173293333333333) + # parser.add_argument("--normal_std", type=float, default=0.9108628719489077) + parser.add_argument('--start', default=0, type=int) + parser.add_argument('--end', default=8236, type=int) + parser.add_argument("--pretrained_model_name_or_path", type=str, + default='/fsx_laion/alvin/pretrain/stable-diffusion-2-base') + parser.add_argument("--pretrained_model_name_or_path2", type=str, + default='/fsx_laion/alvin/pretrain/stable-diffusion-xl-base-1.0') + parser.add_argument('--prediction_type', type=str, default='v_prediction', + choices=['epsilon', 'v_prediction', 'target'], help='Select a mode') + parser.add_argument('--prediction_type2', type=str, default='epsilon', + choices=['epsilon', 'v_prediction', 'target'], help='Select a mode') + parser.add_argument("--cond_num", type=int, default=3) + parser.add_argument('--fusion', type=str, default="sum") + parser.add_argument("--validation_steps", type=int, default=500, ) + parser.add_argument("--test_data_dir", nargs='+', type=str, default=None, ) + parser.add_argument('--filter_lowres', default=False, action="store_true") + parser.add_argument("--filter_res", type=int) + parser.add_argument('--noisy_cond', type=str, default=[], nargs="+", help='add which types of conditions') + parser.add_argument("--output_dir2", type=str, default="sd-model-finetuned") + parser.add_argument('--cond_reshape2', type=str, choices=['resize', 'vae', 'learn_conv'], + help='how to reshape the spatial condition to the same shape as the latent space size') + parser.add_argument('--inference_folder_name2', type=str, + help='how to reshape the spatial condition to the same shape as the latent space size') + parser.add_argument('--cond_inject2', type=str, choices=['concat', 'spade', 'sum'], + help='how to inject the spatial condition') + parser.add_argument('--cond_type2', type=str, default=[], nargs="+", help='add which types of conditions') + parser.add_argument('--cond_type_test2', type=str, default=None, nargs="+", help='add which types of conditions') + parser.add_argument("--resume_from_checkpoint2", type=str, default=None) + parser.add_argument('--pred_cond2', default=False, action="store_true") + parser.add_argument('--save_cond2', default=False, action="store_true") + parser.add_argument('--inference_folder_name', type=str, + default="/data_laion/yli12/code_new/ControlNet/output-images", + help='how to reshape the spatial condition to the same shape as the latent space size') + parser.add_argument('--grid_dnc', default=False, action="store_true") + parser.add_argument('--pred_cond', default=False, action="store_true") + parser.add_argument('--save_cond', default=False, action="store_true") + parser.add_argument('--cond_reshape', type=str, choices=['resize', 'vae', 'learn_conv'], + help='how to reshape the spatial condition to the same shape as the latent space size') + parser.add_argument('--cond_inject', type=str, choices=['concat', 'spade', 'sum'], + help='how to inject the spatial condition') + parser.add_argument('--cond_type', type=str, default=["body", "midas_depth", "normal"], nargs="+", + help='add which types of conditions') + parser.add_argument('--cond_type_test', type=str, default=None, nargs="+", help='add which types of conditions') + parser.add_argument("--embedder_channel", default=4, type=int, help="channel number.") + ################################### newly added args ################################### + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + nargs='+', + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=7, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # # Sanity checks + # if args.dataset_name is None and args.train_data_dir is None: + # raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def main(): + args = parse_args() + if args.change_whole_to_body: + args.whole_mean = args.body_mean + args.whole_std = args.body_std + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + # logging_dir = os.path.join(args.output_dir, args.logging_dir) + + # accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + # log_with=args.report_to, + # logging_dir=logging_dir, + # project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # text_encoder = CLIPTextModel.from_pretrained( + # args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + # ) + vae = AutoencoderKL.from_pretrained( + "/fsx_laion/alvin/pretrain/sd-vae-ft-mse" + ) + + vae_path = ( + args.pretrained_model_name_or_path2 + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae2 = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + ) + + from diffusers.models.unet_2d_condition_multi_branch_downup import UNet2DConditionModel + unet_t2mn = UNet2DConditionModel.from_pretrained(args.t2mn_path, subfolder="unet_ema") + unet_t2mn.requires_grad_(False) + + unet = diffusers.UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path2, subfolder="unet", revision=args.revision, use_auth_token=True + ) + + # if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, subfolder="controlnet") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(False) + vae2.requires_grad_(False) + unet.requires_grad_(False) + unet_t2mn.requires_grad_(False) + # text_encoder.requires_grad_(False) + controlnet.requires_grad_(False) + + unet.eval() + unet_t2mn.eval() + controlnet.eval() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + unet_t2mn.enable_gradient_checkpointing() + controlnet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + tf = transforms.Compose( + [transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(512), + ] + ) + + from mmpose.apis import MMPoseInferencer + # import mmcv + + body_inferencer = MMPoseInferencer( + pose2d='/fsx_laion/alvin/mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_ViTPose-huge-simple_8xb64-210e_coco-256x192.py', + pose2d_weights='/fsx_laion/alvin/pretrain/ViTPose/td-hm_ViTPose-huge-simple_8xb64-210e_coco-256x192-ffd48c05_20230314.pth', + scope="mmpose" + # det_model='/fsx_laion/alvin/mmpose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py', + # det_weights="/fsx_laion/alvin/pretrain/ViTPose/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth" + ) + + input_img = PIL.Image.open(args.ref_path) + input_img = tf(input_img) + + image = np.array(input_img.convert("RGB")) + img_list = [image] + result_generator = body_inferencer(img_list, return_datasample=True) + result = next(result_generator) + + # output[img_id]["new_body_bbox"] = result['predictions'][0].pred_instances.bboxes.tolist() + # output[img_id]["new_body_bbox_score"] = result['predictions'][0].pred_instances.bbox_scores.tolist() + # output[img_id]["new_body_kp"] = result['predictions'][0].pred_instances.keypoints.tolist() + # output[img_id]["new_body_kp_score"] = result['predictions'][0].pred_instances.keypoint_scores.tolist() + + kp_coord = result['predictions'][0].pred_instances.keypoints + kp_coord_1024 = kp_coord * 2. + kp_conf = result['predictions'][0].pred_instances.keypoint_scores + kp = np.concatenate([kp_coord, kp_conf[..., np.newaxis]], axis=-1) + kp_1024 = np.concatenate([kp_coord_1024, kp_conf[..., np.newaxis]], axis=-1) + + whole_draw = draw_humansd_skeleton( + image=None, + pose=kp, + height=512, + width=512, + humansd_skeleton_width=10, + ) + whole_image = Image.fromarray(whole_draw) + + whole_draw_1024 = draw_humansd_skeleton( + # image=np.array(sample["image"]), + image=None, + pose=kp_1024, + height=1024, + width=1024, + humansd_skeleton_width=20, + ) + whole_image_1024 = Image.fromarray(whole_draw_1024) + + preprocess = transforms.Compose( + [ + transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + preprocess_1024 = transforms.Compose( + [ + transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + whole = preprocess(whole_image) + whole_1024 = preprocess_1024(whole_image_1024) + + # dataset = CustomDataset(args) + # test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn) + + # lr_scheduler = get_scheduler( + # args.lr_scheduler, + # optimizer=optimizer, + # num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + # num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + # ) + + # # Prepare everything with our `accelerator`. + unet, unet_t2mn, controlnet = accelerator.prepare( + unet, unet_t2mn, controlnet + ) + + # Move text_encode and vae to gpu and cast to weight_dtype + # text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + vae2.to(accelerator.device, dtype=weight_dtype) + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + # text_encoder=text_encoder, + # tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet_t2mn), + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + if args.flaw: + pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config, rescale_betas_zero_snr=True, + timestep_spacing="trailing") + pipeline.scheduler.config.rescale_betas_zero_snr = True + pipeline.scheduler.config['rescale_betas_zero_snr'] = True + pipeline.scheduler.config.timestep_spacing = "trailing" + pipeline.scheduler.config['timestep_spacing'] = "trailing" + else: + pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + pipeline.scheduler.set_timesteps(args.step_num1) + + pipeline.scheduler.config.prediction_type = args.prediction_type + pipeline.scheduler.config['prediction_type'] = args.prediction_type + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=False) + + controlnet = accelerator.unwrap_model(controlnet) + + pipeline2 = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path2, + vae=vae2, + # text_encoder=text_encoder, + # tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + controlnet=controlnet, + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + # pipeline2.scheduler = UniPCMultistepScheduler.from_config(pipeline2.scheduler.config) + pipeline2.scheduler = DDPMScheduler.from_config(pipeline2.scheduler.config) + pipeline2.scheduler.config.prediction_type = args.prediction_type2 + pipeline2.scheduler.config['prediction_type'] = args.prediction_type2 + pipeline2 = pipeline2.to(accelerator.device) + pipeline2.set_progress_bar_config(disable=False) + + refiner = DiffusionPipeline.from_pretrained( + "/fsx_laion/alvin/pretrain/stable-diffusion-xl-refiner-1.0", + text_encoder_2=pipeline2.text_encoder_2, + vae=pipeline2.vae, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", + ) + # refiner.scheduler = UniPCMultistepScheduler.from_config(pipeline2.scheduler.config) + refiner.scheduler = DDPMScheduler.from_config(refiner.scheduler.config) + refiner.scheduler.config.prediction_type = args.prediction_type2 + refiner.scheduler.config['prediction_type'] = args.prediction_type2 + refiner = refiner.to(accelerator.device) + refiner.set_progress_bar_config(disable=False) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + pipeline2.enable_xformers_memory_efficient_attention() + refiner.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + # step1 = args.t2mn_path.split('/')[-1].split("-")[1] + # step2 = args.controlnet_model_name_or_path.split('/')[-1].split("-")[1] + + os.makedirs(args.inference_folder_name, exist_ok=True) + # save_path_body = os.path.join(save_path, 'body') + # save_path_depth = os.path.join(save_path, 'depth') + # save_path_normal = os.path.join(save_path, 'normal') + # save_path_rgb1 = os.path.join(save_path, 'rgb1') + # save_path_rgb2 = os.path.join(save_path, 'rgb2') + # os.makedirs(save_path_body, exist_ok=True) + # os.makedirs(save_path_depth, exist_ok=True) + # os.makedirs(save_path_normal, exist_ok=True) + # os.makedirs(save_path_rgb1, exist_ok=True) + # os.makedirs(save_path_rgb2, exist_ok=True) + + batch = {} + whole = whole.to(unet.device) + whole_1024 = whole_1024.to(unet.device) + batch["whole"] = whole.unsqueeze(0) + batch["body"] = whole_1024.unsqueeze(0) + + with torch.autocast("cuda"): + output = pipeline( + args.prompt, + height=args.resolution, + width=args.resolution, + num_inference_steps=args.step_num1, + generator=generator, + batch=batch, + args=args, + original_size=(args.size, args.size), + guidance_rescale=0.7 if args.flaw else 0., + negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", + ) + + image = output.images[0] + image.save(os.path.join(args.inference_folder_name, "rgb.png")) + midas_depth_image = output.midas_depth_image[0] + midas_depth_image.save(os.path.join(args.inference_folder_name, "depth.png")) + normal_image = output.normal_image[0] + normal_image.save(os.path.join(args.inference_folder_name, "normal.png")) + + resize_transform = transforms.Compose( + [ + transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + # transforms.Normalize([0.5], [0.5]), + ] + ) + normalize_transform = transforms.Normalize([0.5], [0.5]) + # midas_depth_tensor = 2 * (transforms.ToTensor()(midas_depth_image)) - 1 + midas_depth_tensor = resize_transform(midas_depth_image) + # print(midas_depth_tensor.shape) + midas_depth_tensor = torch.mean(midas_depth_tensor, dim=0) + # print(midas_depth_tensor.shape) + depth_min = torch.amin(midas_depth_tensor, dim=[0, 1], keepdim=True) + depth_max = torch.amax(midas_depth_tensor, dim=[0, 1], keepdim=True) + midas_depth_tensor = (midas_depth_tensor - depth_min) / (depth_max - depth_min) + midas_depth_tensor = normalize_transform(midas_depth_tensor.unsqueeze(0).repeat(3, 1, 1)) + batch["midas_depth"] = midas_depth_tensor.unsqueeze(0).to(unet.device) + + # normal_tensor = 2 * (transforms.ToTensor()(normal_image)) - 1 + normal_tensor = resize_transform(normal_image) + normal_tensor = normal_tensor.clamp(min=0, max=1) + normal_tensor = normalize_transform(normal_tensor) + batch["normal"] = normal_tensor.unsqueeze(0).to(unet.device) + + body_denormalize = (batch["body"] + 1) / 2.0 + body_numpy = body_denormalize.cpu().permute(0, 2, 3, 1).float().numpy()[0] + body_numpy = (body_numpy * 255).round().astype("uint8") + body_pil = Image.fromarray(body_numpy) + body_pil.save(os.path.join(args.inference_folder_name, "body.png")) + # batch["body"] = batch["body"][0].unsqueeze(0) + + # batch["whole"] = batch["whole_1024"] + + controlnet_image = [] + for key in ['depth', 'midas_depth', 'normal', 'canny', 'body', 'face', 'hand', 'whole']: + if key in args.cond_type: + controlnet_image.append(batch[key][0]) + + n_steps = args.step_num2 + high_noise_frac = 0.8 + + with torch.autocast("cuda"): + output = pipeline2( + args.prompt, + image=controlnet_image, + height=1024, + width=1024, + num_inference_steps=n_steps, + denoising_end=high_noise_frac, + output_type="latent", + generator=generator, + original_size=(args.size, args.size), + negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", + ) + + # image = output.images[0] + # image.save(os.path.join(save_path_rgb2, f"{int(id[i_batch]):012d}.jpg")) + + image = output.images + image = refiner( + args.prompt, + # height=1024, + # width=1024, + num_inference_steps=n_steps, + denoising_start=high_noise_frac, + image=image, + # guidance_scale=args.cfg, + generator=generator, + negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", + ).images[0] + image.save(os.path.join(args.inference_folder_name, "rgb2.png")) + + +if __name__ == "__main__": + main() diff --git a/aagenerator.py b/aagenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd96a8970b34d097127db664ab244714d7e5a0b --- /dev/null +++ b/aagenerator.py @@ -0,0 +1,981 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers import AutoTokenizer, PretrainedConfig +from transformers.utils import ContextManagers +from PIL import Image +import PIL +from PIL import ImageFile + +import diffusers +from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from openclip.training.data import get_wds_dataset, get_wds_dataset_cond +from diffusers.schedulers import DDIMScheduler, DDPMScheduler, \ + DEISMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, \ + PNDMScheduler, LMSDiscreteScheduler, UniPCMultistepScheduler + +from models.embedder import Embedder +from pipelines.pipeline_stable_diffusion_mb_downup import StableDiffusionPipeline +from collections import OrderedDict +import boto3 +from diffusers.models.controlnet_composer import ControlNetModel +# from pipelines.pipeline_controlnet_composer import StableDiffusionControlNetPipeline +from pipelines.pipeline_controlnet_composer_sdxl import StableDiffusionXLControlNetPipeline +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info +import json +import cv2 +import seaborn as sns +from mmpose.apis import MMPoseInferencer +from diffusers.models.unet_2d_condition_multi_branch_downup import UNet2DConditionModel + +if is_wandb_available(): + import wandb + +logger = get_logger(__name__, log_level="INFO") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.0.dev0") + + +def draw_humansd_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None, humansd_skeleton_width=10): + humansd_skeleton = [ + [0, 0, 1], + [1, 0, 2], + [2, 1, 3], + [3, 2, 4], + [4, 3, 5], + [5, 4, 6], + [6, 5, 7], + [7, 6, 8], + [8, 7, 9], + [9, 8, 10], + [10, 5, 11], + [11, 6, 12], + [12, 11, 13], + [13, 12, 14], + [14, 13, 15], + [15, 14, 16], + ] + # humansd_skeleton_width=10 + humansd_color = sns.color_palette("hls", len(humansd_skeleton)) + + def plot_kpts(img_draw, kpts, color, edgs, width): + for idx, kpta, kptb in edgs: + if kpts[kpta, 2] > mmpose_detection_thresh and \ + kpts[kptb, 2] > mmpose_detection_thresh: + line_color = tuple([int(255 * color_i) for color_i in color[idx]]) + + cv2.line(img_draw, (int(kpts[kpta, 0]), int(kpts[kpta, 1])), (int(kpts[kptb, 0]), int(kpts[kptb, 1])), + line_color, width) + cv2.circle(img_draw, (int(kpts[kpta, 0]), int(kpts[kpta, 1])), width // 2, line_color, -1) + cv2.circle(img_draw, (int(kpts[kptb, 0]), int(kpts[kptb, 1])), width // 2, line_color, -1) + + if image is None: + pose_image = np.zeros((height, width, 3), dtype=np.uint8) + else: + pose_image = np.array(image, dtype=np.uint8) + for person_i in range(len(pose)): + if np.sum(pose[person_i]) > 0: + plot_kpts(pose_image, pose[person_i], humansd_color, humansd_skeleton, humansd_skeleton_width) + + return pose_image + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + ################################### newly added args ################################### + parser.add_argument("--ref_path", type=str, default="/data_laion/alvin/Dataset/evaluation/debug/42361.png") + parser.add_argument("--prompt", type=str, default="A person riding skis down a snow covered slope.") + parser.add_argument("--t2mn_path", type=str, + default="/data_laion/alvin/sd4human/ckpts/a-ranstart-body-sdv20-v-nd-flaw-avg-copy1-glc-resume288k-512-ft1024/checkpoint-388000") + parser.add_argument("--controlnet_model_name_or_path", type=str, + default="/data_laion/alvin/sd4human/ckpts/ctrl-sdxl10-eps-glc-composer-bmn-sum-1024/checkpoint-91000") + parser.add_argument('--step_num1', default=50, type=int) + parser.add_argument('--step_num2', default=50, type=int) + parser.add_argument('--size', default=2048, type=int) + parser.add_argument("--pretrained_vae_model_name_or_path", type=str, + default='/fsx_laion/alvin/pretrain/sdxl-vae-fp16-fix') + parser.add_argument('--normalize_dist', default=True, action="store_false") + parser.add_argument('--change_whole_to_body', default=True, action="store_false") + parser.add_argument('--off_wa', default=True, action="store_false") + parser.add_argument('--flaw', default=True, action="store_false") + parser.add_argument("--enable_xformers_memory_efficient_attention", default=True, action="store_false", + help="Whether or not to use xformers.") + # statistics for three datasets, laion+coyo+getty + parser.add_argument("--rgb_mean", type=float, default=0.14654) + parser.add_argument("--rgb_std", type=float, default=1.03744) + # parser.add_argument("--whole_mean", type=float, default=0.14713) + # parser.add_argument("--whole_std", type=float, default=0.96812) + parser.add_argument("--whole_mean", type=float, default=-0.2599426086956522) + parser.add_argument("--whole_std", type=float, default=1.3836632689065582) + parser.add_argument("--body_mean", type=float, default=-0.2481) + parser.add_argument("--body_std", type=float, default=1.45647) + parser.add_argument("--depth_mean", type=float, default=0.21360) + parser.add_argument("--depth_std", type=float, default=1.20629) + parser.add_argument("--normal_mean", type=float, default=0.60303) + parser.add_argument("--normal_std", type=float, default=0.91429) + + # # statistics for two datasetsm laion+coyo + # parser.add_argument("--rgb_mean", type=float, default=0.144028) + # parser.add_argument("--rgb_std", type=float, default=1.0420677550094796) + # parser.add_argument("--whole_mean", type=float, default=-0.2598586666666667) + # parser.add_argument("--whole_std", type=float, default=1.3824869261991977) + # parser.add_argument("--body_mean", type=float, default=-0.2481) + # parser.add_argument("--body_std", type=float, default=1.45647) + # parser.add_argument("--depth_mean", type=float, default=0.22104533333333334) + # parser.add_argument("--depth_std", type=float, default=1.2044201368629092) + # parser.add_argument("--normal_mean", type=float, default=0.6173293333333333) + # parser.add_argument("--normal_std", type=float, default=0.9108628719489077) + parser.add_argument('--start', default=0, type=int) + parser.add_argument('--end', default=8236, type=int) + parser.add_argument("--pretrained_model_name_or_path", type=str, + default='/fsx_laion/alvin/pretrain/stable-diffusion-2-base') + parser.add_argument("--pretrained_model_name_or_path2", type=str, + default='/fsx_laion/alvin/pretrain/stable-diffusion-xl-base-1.0') + parser.add_argument('--prediction_type', type=str, default='v_prediction', + choices=['epsilon', 'v_prediction', 'target'], help='Select a mode') + parser.add_argument('--prediction_type2', type=str, default='epsilon', + choices=['epsilon', 'v_prediction', 'target'], help='Select a mode') + parser.add_argument("--cond_num", type=int, default=3) + parser.add_argument('--fusion', type=str, default="sum") + parser.add_argument("--validation_steps", type=int, default=500, ) + parser.add_argument("--test_data_dir", nargs='+', type=str, default=None, ) + parser.add_argument('--filter_lowres', default=False, action="store_true") + parser.add_argument("--filter_res", type=int) + parser.add_argument('--noisy_cond', type=str, default=[], nargs="+", help='add which types of conditions') + parser.add_argument("--output_dir2", type=str, default="sd-model-finetuned") + parser.add_argument('--cond_reshape2', type=str, choices=['resize', 'vae', 'learn_conv'], + help='how to reshape the spatial condition to the same shape as the latent space size') + parser.add_argument('--inference_folder_name2', type=str, + help='how to reshape the spatial condition to the same shape as the latent space size') + parser.add_argument('--cond_inject2', type=str, choices=['concat', 'spade', 'sum'], + help='how to inject the spatial condition') + parser.add_argument('--cond_type2', type=str, default=[], nargs="+", help='add which types of conditions') + parser.add_argument('--cond_type_test2', type=str, default=None, nargs="+", help='add which types of conditions') + parser.add_argument("--resume_from_checkpoint2", type=str, default=None) + parser.add_argument('--pred_cond2', default=False, action="store_true") + parser.add_argument('--save_cond2', default=False, action="store_true") + parser.add_argument('--inference_folder_name', type=str, + default="/data_laion/yli12/code_new/ControlNet/output-images", + help='how to reshape the spatial condition to the same shape as the latent space size') + parser.add_argument('--grid_dnc', default=False, action="store_true") + parser.add_argument('--pred_cond', default=False, action="store_true") + parser.add_argument('--save_cond', default=False, action="store_true") + parser.add_argument('--cond_reshape', type=str, choices=['resize', 'vae', 'learn_conv'], + help='how to reshape the spatial condition to the same shape as the latent space size') + parser.add_argument('--cond_inject', type=str, choices=['concat', 'spade', 'sum'], + help='how to inject the spatial condition') + parser.add_argument('--cond_type', type=str, default=["body", "midas_depth", "normal"], nargs="+", + help='add which types of conditions') + parser.add_argument('--cond_type_test', type=str, default=None, nargs="+", help='add which types of conditions') + parser.add_argument("--embedder_channel", default=4, type=int, help="channel number.") + ################################### newly added args ################################### + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + nargs='+', + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=7, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.self.accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to self.accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # # Sanity checks + # if args.dataset_name is None and args.train_data_dir is None: + # raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +class Generator: + def __init__(self, args=parse_args()): + super().__init__() + self.args = args + if args.change_whole_to_body: + args.whole_mean = args.body_mean + args.whole_std = args.body_std + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + # logging_dir = os.path.join(args.output_dir, args.logging_dir) + + # accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit, logging_dir=logging_dir) + + self.accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + # log_with=args.report_to, + # logging_dir=logging_dir, + # project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(self.accelerator.state, main_process_only=False) + if self.accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if self.accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # text_encoder = CLIPTextModel.from_pretrained( + # args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + # ) + self.vae = AutoencoderKL.from_pretrained( + "/fsx_laion/alvin/pretrain/sd-vae-ft-mse" + ) + + vae_path = ( + args.pretrained_model_name_or_path2 + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + self.vae2 = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + ) + + self.unet_t2mn = UNet2DConditionModel.from_pretrained(args.t2mn_path, subfolder="unet_ema") + self.unet_t2mn.requires_grad_(False) + + self.unet = diffusers.UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path2, subfolder="unet", revision=args.revision, use_auth_token=True + ) + + # if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + self.controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, subfolder="controlnet") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `self.accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + self.accelerator.register_save_state_pre_hook(save_model_hook) + self.accelerator.register_load_state_pre_hook(load_model_hook) + + self.vae.requires_grad_(False) + self.vae2.requires_grad_(False) + self.unet.requires_grad_(False) + self.unet_t2mn.requires_grad_(False) + # text_encoder.requires_grad_(False) + self.controlnet.requires_grad_(False) + + self.unet.eval() + self.unet_t2mn.eval() + self.controlnet.eval() + + if args.gradient_checkpointing: + self.unet.enable_gradient_checkpointing() + self.unet_t2mn.enable_gradient_checkpointing() + self.controlnet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if self.accelerator.unwrap_model(self.controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {self.accelerator.unwrap_model(self.controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * self.accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + self.optimizer_class = bnb.optim.AdamW8bit + else: + self.optimizer_class = torch.optim.AdamW + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if self.accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + self.tf = transforms.Compose( + [transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(512), + ] + ) + + self.body_inferencer = MMPoseInferencer( + pose2d='/fsx_laion/alvin/mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_ViTPose-huge-simple_8xb64-210e_coco-256x192.py', + pose2d_weights='/fsx_laion/alvin/pretrain/ViTPose/td-hm_ViTPose-huge-simple_8xb64-210e_coco-256x192-ffd48c05_20230314.pth', + scope="mmpose" + # det_model='/fsx_laion/alvin/mmpose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py', + # det_weights="/fsx_laion/alvin/pretrain/ViTPose/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth" + ) + + # # Prepare everything with our `accelerator`. + self.unet, self.unet_t2mn, self.controlnet = self.accelerator.prepare( + self.unet, self.unet_t2mn, self.controlnet + ) + + # Move text_encode and vae to gpu and cast to weight_dtype + # text_encoder.to(self.accelerator.device, dtype=weight_dtype) + self.vae.to(self.accelerator.device, dtype=weight_dtype) + self.vae2.to(self.accelerator.device, dtype=weight_dtype) + + self.pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=self.vae, + # text_encoder=text_encoder, + # tokenizer=tokenizer, + unet=self.accelerator.unwrap_model(self.unet_t2mn), + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + if args.flaw: + self.pipeline.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config, + rescale_betas_zero_snr=True, + timestep_spacing="trailing") + self.pipeline.scheduler.config.rescale_betas_zero_snr = True + self.pipeline.scheduler.config['rescale_betas_zero_snr'] = True + self.pipeline.scheduler.config.timestep_spacing = "trailing" + self.pipeline.scheduler.config['timestep_spacing'] = "trailing" + else: + self.pipeline.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config) + self.pipeline.scheduler.set_timesteps(args.step_num1) + + self.pipeline.scheduler.config.prediction_type = args.prediction_type + self.pipeline.scheduler.config['prediction_type'] = args.prediction_type + + self.pipeline = self.pipeline.to(self.accelerator.device) + self.pipeline.set_progress_bar_config(disable=False) + + self.controlnet = self.accelerator.unwrap_model(self.controlnet) + + self.pipeline2 = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path2, + vae=self.vae2, + # text_encoder=text_encoder, + # tokenizer=tokenizer, + unet=self.accelerator.unwrap_model(self.unet), + controlnet=self.controlnet, + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + # pipeline2.scheduler = UniPCMultistepScheduler.from_config(pipeline2.scheduler.config) + self.pipeline2.scheduler = DDPMScheduler.from_config(self.pipeline2.scheduler.config) + self.pipeline2.scheduler.config.prediction_type = args.prediction_type2 + self.pipeline2.scheduler.config['prediction_type'] = args.prediction_type2 + self.pipeline2 = self.pipeline2.to(self.accelerator.device) + self.pipeline2.set_progress_bar_config(disable=False) + + self.refiner = DiffusionPipeline.from_pretrained( + "/fsx_laion/alvin/pretrain/stable-diffusion-xl-refiner-1.0", + text_encoder_2=self.pipeline2.text_encoder_2, + vae=self.pipeline2.vae, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", + ) + # refiner.scheduler = UniPCMultistepScheduler.from_config(pipeline2.scheduler.config) + self.refiner.scheduler = DDPMScheduler.from_config(self.refiner.scheduler.config) + self.refiner.scheduler.config.prediction_type = args.prediction_type2 + self.refiner.scheduler.config['prediction_type'] = args.prediction_type2 + self.refiner = self.refiner.to(self.accelerator.device) + self.refiner.set_progress_bar_config(disable=False) + + if args.enable_xformers_memory_efficient_attention: + self.pipeline.enable_xformers_memory_efficient_attention() + self.pipeline2.enable_xformers_memory_efficient_attention() + self.refiner.enable_xformers_memory_efficient_attention() + + if args.seed is None: + self.generator = None + else: + self.generator = torch.Generator(device=self.accelerator.device).manual_seed(args.seed) + + # step1 = args.t2mn_path.split('/')[-1].split("-")[1] + # step2 = args.controlnet_model_name_or_path.split('/')[-1].split("-")[1] + + self.preprocess = transforms.Compose( + [ + transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.preprocess_1024 = transforms.Compose( + [ + transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.resize_transform = transforms.Compose( + [ + transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + # transforms.Normalize([0.5], [0.5]), + ] + ) + self.normalize_transform = transforms.Normalize([0.5], [0.5]) + + os.makedirs(args.inference_folder_name, exist_ok=True) + + def run(self, input_img=None, prompt=None, steps=None): + if input_img is None: + input_img = PIL.Image.open(self.args.ref_path) + else: + input_img = PIL.Image.fromarray(input_img) + input_img = self.tf(input_img) + image = np.array(input_img.convert("RGB")) + img_list = [image] + result_generator = self.body_inferencer(img_list, return_datasample=True) + result = next(result_generator) + + kp_coord = result['predictions'][0].pred_instances.keypoints + kp_coord_1024 = kp_coord * 2. + kp_conf = result['predictions'][0].pred_instances.keypoint_scores + kp = np.concatenate([kp_coord, kp_conf[..., np.newaxis]], axis=-1) + kp_1024 = np.concatenate([kp_coord_1024, kp_conf[..., np.newaxis]], axis=-1) + + whole_draw = draw_humansd_skeleton( + image=None, + pose=kp, + height=512, + width=512, + humansd_skeleton_width=10, + ) + whole_image = Image.fromarray(whole_draw) + + whole_draw_1024 = draw_humansd_skeleton( + # image=np.array(sample["image"]), + image=None, + pose=kp_1024, + height=1024, + width=1024, + humansd_skeleton_width=20, + ) + whole_image_1024 = Image.fromarray(whole_draw_1024) + + whole = self.preprocess(whole_image) + whole_1024 = self.preprocess_1024(whole_image_1024) + + batch = {} + whole = whole.to(self.unet.device) + whole_1024 = whole_1024.to(self.unet.device) + batch["whole"] = whole.unsqueeze(0) + batch["body"] = whole_1024.unsqueeze(0) + + body_denormalize = (batch["body"] + 1) / 2.0 + body_numpy = body_denormalize.cpu().permute(0, 2, 3, 1).float().numpy()[0] + body_numpy = (body_numpy * 255).round().astype("uint8") + + with torch.autocast("cuda"): + output = self.pipeline( + prompt or self.args.prompt, + height=self.args.resolution, + width=self.args.resolution, + num_inference_steps=steps or self.args.step_num1, + generator=self.generator, + batch=batch, + args=self.args, + original_size=(self.args.size, self.args.size), + guidance_rescale=0.7 if self.args.flaw else 0., + negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", + ) + + rgb = output.images[0] + midas_depth_image = output.midas_depth_image[0] + normal_image = output.normal_image[0] + + # midas_depth_tensor = 2 * (transforms.ToTensor()(midas_depth_image)) - 1 + midas_depth_tensor = self.resize_transform(midas_depth_image) + # print(midas_depth_tensor.shape) + midas_depth_tensor = torch.mean(midas_depth_tensor, dim=0) + # print(midas_depth_tensor.shape) + depth_min = torch.amin(midas_depth_tensor, dim=[0, 1], keepdim=True) + depth_max = torch.amax(midas_depth_tensor, dim=[0, 1], keepdim=True) + midas_depth_tensor = (midas_depth_tensor - depth_min) / (depth_max - depth_min) + midas_depth_tensor = self.normalize_transform(midas_depth_tensor.unsqueeze(0).repeat(3, 1, 1)) + batch["midas_depth"] = midas_depth_tensor.unsqueeze(0).to(self.unet.device) + + # normal_tensor = 2 * (transforms.ToTensor()(normal_image)) - 1 + normal_tensor = self.resize_transform(normal_image) + normal_tensor = normal_tensor.clamp(min=0, max=1) + normal_tensor = self.normalize_transform(normal_tensor) + batch["normal"] = normal_tensor.unsqueeze(0).to(self.unet.device) + + # batch["body"] = batch["body"][0].unsqueeze(0) + # batch["whole"] = batch["whole_1024"] + + controlnet_image = [] + for key in ['depth', 'midas_depth', 'normal', 'canny', 'body', 'face', 'hand', 'whole']: + if key in self.args.cond_type: + controlnet_image.append(batch[key][0]) + + n_steps = steps or self.args.step_num2 + high_noise_frac = 0.8 + + with torch.autocast("cuda"): + output2 = self.pipeline2( + prompt or self.args.prompt, + image=controlnet_image, + height=1024, + width=1024, + num_inference_steps=n_steps, + denoising_end=high_noise_frac, + output_type="latent", + generator=self.generator, + original_size=(self.args.size, self.args.size), + negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", + ) + + # image = output.images[0] + # image.save(os.path.join(save_path_rgb2, f"{int(id[i_batch]):012d}.jpg")) + + image = output2.images + rgb2 = self.refiner( + prompt or self.args.prompt, + # height=1024, + # width=1024, + num_inference_steps=n_steps, + denoising_start=high_noise_frac, + image=image, + # guidance_scale=self.args.cfg, + generator=self.generator, + negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", + ).images[0] + + return body_numpy, rgb, midas_depth_image, normal_image, rgb2 + + def save_images(self, input_img=None, prompt=None): + body_numpy, rgb, midas_depth_image, normal_image, rgb2 = self.run(input_img=input_img, prompt=prompt) + body_pil = Image.fromarray(body_numpy) + body_pil.save(os.path.join(self.args.inference_folder_name, "body.png")) + + rgb.save(os.path.join(self.args.inference_folder_name, "rgb.png")) + midas_depth_image.save(os.path.join(self.args.inference_folder_name, "depth.png")) + normal_image.save(os.path.join(self.args.inference_folder_name, "normal.png")) + + rgb2.save(os.path.join(self.args.inference_folder_name, "rgb2.png")) + + +if __name__ == "__main__": + gen = Generator() + gen.save_images() diff --git a/annotator/__pycache__/util.cpython-38.pyc b/annotator/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..645da9c23f62d4922b09d826528259552daa38d5 Binary files /dev/null and b/annotator/__pycache__/util.cpython-38.pyc differ diff --git a/annotator/canny/__init__.py b/annotator/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b --- /dev/null +++ b/annotator/canny/__init__.py @@ -0,0 +1,6 @@ +import cv2 + + +class CannyDetector: + def __call__(self, img, low_threshold, high_threshold): + return cv2.Canny(img, low_threshold, high_threshold) diff --git a/annotator/ckpts/body_pose_model.pth b/annotator/ckpts/body_pose_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..9acb77e68f31906a8875f1daef2f3f7ef94acb1e --- /dev/null +++ b/annotator/ckpts/body_pose_model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25a948c16078b0f08e236bda51a385d855ef4c153598947c28c0d47ed94bb746 +size 209267595 diff --git a/annotator/ckpts/ckpts.txt b/annotator/ckpts/ckpts.txt new file mode 100644 index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b --- /dev/null +++ b/annotator/ckpts/ckpts.txt @@ -0,0 +1 @@ +Weights here. \ No newline at end of file diff --git a/annotator/ckpts/hand_pose_model.pth b/annotator/ckpts/hand_pose_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..f23ccf3413cc8ac8581a82338a3037bc10d573f0 --- /dev/null +++ b/annotator/ckpts/hand_pose_model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b76b00d1750901abd07b9f9d8c98cc3385b8fe834a26d4b4f0aad439e75fc600 +size 147341049 diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a8fc712fba02b033dea13bfe33204b8d3c9139 --- /dev/null +++ b/annotator/hed/__init__.py @@ -0,0 +1,96 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. +# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import cv2 +import torch +import numpy as np + +from einops import rearrange +from annotator.util import annotator_ckpts_path + + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + + +class HEDdetector: + def __init__(self): + remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth" + modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + self.netNetwork = ControlNetHED_Apache2().float().cuda().eval() + self.netNetwork.load_state_dict(torch.load(modelpath)) + + def __call__(self, input_image): + assert input_image.ndim == 3 + H, W, C = input_image.shape + with torch.no_grad(): + image_hed = torch.from_numpy(input_image.copy()).float().cuda() + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + return edge + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z diff --git a/annotator/midas/LICENSE b/annotator/midas/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/annotator/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/annotator/midas/__init__.py b/annotator/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36789767f35bcc169c2cbf096e2747539df4f14d --- /dev/null +++ b/annotator/midas/__init__.py @@ -0,0 +1,42 @@ +# Midas Depth Estimation +# From https://github.com/isl-org/MiDaS +# MIT LICENSE + +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self): + self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + + return depth_image, normal_image diff --git a/annotator/midas/api.py b/annotator/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab9f15bf96bbaffcee0e3e29fc9d3979d6c32e8 --- /dev/null +++ b/annotator/midas/api.py @@ -0,0 +1,169 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from .midas.dpt_depth import DPTDepthModel +from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from annotator.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + +remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/annotator/midas/midas/__init__.py b/annotator/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/annotator/midas/midas/base_model.py b/annotator/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/annotator/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/annotator/midas/midas/blocks.py b/annotator/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/annotator/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/annotator/midas/midas/dpt_depth.py b/annotator/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/annotator/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/annotator/midas/midas/midas_net.py b/annotator/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/annotator/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/annotator/midas/midas/midas_net_custom.py b/annotator/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/annotator/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/annotator/midas/midas/transforms.py b/annotator/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/annotator/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/annotator/midas/midas/vit.py b/annotator/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/annotator/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/annotator/midas/utils.py b/annotator/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/annotator/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/annotator/mlsd/LICENSE b/annotator/mlsd/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363 --- /dev/null +++ b/annotator/mlsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021-present NAVER Corp. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1860702df6150c5a93c9bb6bf34906a77048c7c --- /dev/null +++ b/annotator/mlsd/__init__.py @@ -0,0 +1,43 @@ +# MLSD Line Detection +# From https://github.com/navervision/mlsd +# Apache-2.0 license + +import cv2 +import numpy as np +import torch +import os + +from einops import rearrange +from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + +from annotator.util import annotator_ckpts_path + + +remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth" + + +class MLSDdetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + self.model = model.cuda().eval() + + def __call__(self, input_image, thr_v, thr_d): + assert input_image.ndim == 3 + img = input_image + img_output = np.zeros_like(img) + try: + with torch.no_grad(): + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception as e: + pass + return img_output[:, :, 0] diff --git a/annotator/mlsd/models/mbv2_mlsd_large.py b/annotator/mlsd/models/mbv2_mlsd_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603 --- /dev/null +++ b/annotator/mlsd/models/mbv2_mlsd_large.py @@ -0,0 +1,292 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + if self.upscale: + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + self.features = nn.Sequential(*features) + self.fpn_selected = [1, 3, 6, 10, 13] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + if pretrained: + self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c1, c2, c3, c4, c5 = fpn_features + return c1, c2, c3, c4, c5 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Large(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Large, self).__init__() + + self.backbone = MobileNetV2(pretrained=False) + ## A, B + self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, + out_c1= 64, out_c2=64, + upscale=False) + self.block16 = BlockTypeB(128, 64) + + ## A, B + self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, + out_c1= 64, out_c2= 64) + self.block18 = BlockTypeB(128, 64) + + ## A, B + self.block19 = BlockTypeA(in_c1=24, in_c2=64, + out_c1=64, out_c2=64) + self.block20 = BlockTypeB(128, 64) + + ## A, B, C + self.block21 = BlockTypeA(in_c1=16, in_c2=64, + out_c1=64, out_c2=64) + self.block22 = BlockTypeB(128, 64) + + self.block23 = BlockTypeC(64, 16) + + def forward(self, x): + c1, c2, c3, c4, c5 = self.backbone(x) + + x = self.block15(c4, c5) + x = self.block16(x) + + x = self.block17(c3, x) + x = self.block18(x) + + x = self.block19(c2, x) + x = self.block20(x) + + x = self.block21(c1, x) + x = self.block22(x) + x = self.block23(x) + x = x[:, 7:, :, :] + + return x \ No newline at end of file diff --git a/annotator/mlsd/models/mbv2_mlsd_tiny.py b/annotator/mlsd/models/mbv2_mlsd_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83 --- /dev/null +++ b/annotator/mlsd/models/mbv2_mlsd_tiny.py @@ -0,0 +1,275 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + #[6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + self.features = nn.Sequential(*features) + + self.fpn_selected = [3, 6, 10] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + #if pretrained: + # self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c2, c3, c4 = fpn_features + return c2, c3, c4 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Tiny(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Tiny, self).__init__() + + self.backbone = MobileNetV2(pretrained=True) + + self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, + out_c1= 64, out_c2=64) + self.block13 = BlockTypeB(128, 64) + + self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, + out_c1= 32, out_c2= 32) + self.block15 = BlockTypeB(64, 64) + + self.block16 = BlockTypeC(64, 16) + + def forward(self, x): + c2, c3, c4 = self.backbone(x) + + x = self.block12(c3, c4) + x = self.block13(x) + x = self.block14(c2, x) + x = self.block15(x) + x = self.block16(x) + x = x[:, 7:, :, :] + #print(x.shape) + x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) + + return x \ No newline at end of file diff --git a/annotator/mlsd/utils.py b/annotator/mlsd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3cf9420a33a4abae27c48ac4b90938c7d63cc3 --- /dev/null +++ b/annotator/mlsd/utils.py @@ -0,0 +1,580 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. +Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=[512, 512], + score_thr=0.10, + dist_thr=20.0): + h, w, _ = image.shape + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().cuda() + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=[512, 512], + params={'score': 0.06, + 'outside_ratio': 0.28, + 'inside_ratio': 0.45, + 'w_overlap': 0.0, + 'w_degree': 1.95, + 'w_length': 0.0, + 'w_area': 1.86, + 'w_center': 0.14}): + ''' + shape = [height, width] + ''' + h, w, _ = image.shape + original_shape = [h, w] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().cuda() + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] # (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt > 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... + ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! + check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, color_info = 3, 'cyan' + else: + corner_info, color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... + ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + top_square = None + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + start_point = square[start_idx] + end_point = square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + ###################################### + + overlap_scores = np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_w = [0.0, 1.0, 10.0, 0.5, 1.0] + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + best_square = [] + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception as e: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/annotator/openpose/LICENSE b/annotator/openpose/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd --- /dev/null +++ b/annotator/openpose/LICENSE @@ -0,0 +1,108 @@ +OPENPOSE: MULTIPERSON KEYPOINT DETECTION +SOFTWARE LICENSE AGREEMENT +ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY + +BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. + +This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. + +RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: +Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, +non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). + +CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. + +COPYRIGHT: The Software is owned by Licensor and is protected by United +States copyright laws and applicable international treaties and/or conventions. + +PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. + +DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. + +BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. + +USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. + +You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. + +ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. + +TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. + +The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. + +FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. + +DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. + +SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. + +EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. + +EXPORT REGULATION: Licensee agrees to comply with any and all applicable +U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. + +SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. + +NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. + +GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. + +ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. + + + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014-2017 The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014-2017, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** \ No newline at end of file diff --git a/annotator/openpose/__init__.py b/annotator/openpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92e530fd6913a92b1e624d3e334252bcfdba902f --- /dev/null +++ b/annotator/openpose/__init__.py @@ -0,0 +1,49 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import torch +import numpy as np +from . import util +from .body import Body +from .hand import Hand +from annotator.util import annotator_ckpts_path + + +body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth" +hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth" + + +class OpenposeDetector: + def __init__(self): + body_modelpath = os.path.join(annotator_ckpts_path, "body_pose_model.pth") + hand_modelpath = os.path.join(annotator_ckpts_path, "hand_pose_model.pth") + + if not os.path.exists(hand_modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(body_model_path, model_dir=annotator_ckpts_path) + load_file_from_url(hand_model_path, model_dir=annotator_ckpts_path) + + self.body_estimation = Body(body_modelpath) + self.hand_estimation = Hand(hand_modelpath) + + def __call__(self, oriImg, hand=False): + oriImg = oriImg[:, :, ::-1].copy() + with torch.no_grad(): + candidate, subset = self.body_estimation(oriImg) + canvas = np.zeros_like(oriImg) + canvas = util.draw_bodypose(canvas, candidate, subset) + if hand: + hands_list = util.handDetect(candidate, subset, oriImg) + all_hand_peaks = [] + for x, y, w, is_left in hands_list: + peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]) + peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x) + peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y) + all_hand_peaks.append(peaks) + canvas = util.draw_handpose(canvas, all_hand_peaks) + return canvas, dict(candidate=candidate.tolist(), subset=subset.tolist()) diff --git a/annotator/openpose/__pycache__/__init__.cpython-38.pyc b/annotator/openpose/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c738a75676c80a4b350f2968045433638b765e6c Binary files /dev/null and b/annotator/openpose/__pycache__/__init__.cpython-38.pyc differ diff --git a/annotator/openpose/__pycache__/body.cpython-38.pyc b/annotator/openpose/__pycache__/body.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e978aabae0eca387d7d5cec107df39beb555210f Binary files /dev/null and b/annotator/openpose/__pycache__/body.cpython-38.pyc differ diff --git a/annotator/openpose/__pycache__/hand.cpython-38.pyc b/annotator/openpose/__pycache__/hand.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c65ade4e28747a7c8a823cab3b1e84f5ecb3b87a Binary files /dev/null and b/annotator/openpose/__pycache__/hand.cpython-38.pyc differ diff --git a/annotator/openpose/__pycache__/model.cpython-38.pyc b/annotator/openpose/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82eec9bf06f2b7136945421ad13068197e2da22b Binary files /dev/null and b/annotator/openpose/__pycache__/model.cpython-38.pyc differ diff --git a/annotator/openpose/__pycache__/util.cpython-38.pyc b/annotator/openpose/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff4f6a77483f0269317fed94fac4c8ba99420f66 Binary files /dev/null and b/annotator/openpose/__pycache__/util.cpython-38.pyc differ diff --git a/annotator/openpose/body.py b/annotator/openpose/body.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3cf7a388b4ac81004524e64125e383bdd455bd --- /dev/null +++ b/annotator/openpose/body.py @@ -0,0 +1,219 @@ +import cv2 +import numpy as np +import math +import time +from scipy.ndimage.filters import gaussian_filter +import matplotlib.pyplot as plt +import matplotlib +import torch +from torchvision import transforms + +from . import util +from .model import bodypose_model + +class Body(object): + def __init__(self, model_path): + self.model = bodypose_model() + if torch.cuda.is_available(): + self.model = self.model.cuda() + print('cuda') + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImg): + # scale_search = [0.5, 1.0, 1.5, 2.0] + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + # data = data.permute([2, 0, 1]).unsqueeze(0).float() + with torch.no_grad(): + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps + heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) + + # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs + paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) + paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += + paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(18): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) + peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ + [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ + [55, 56], [37, 38], [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + indexA, indexB = limbSeq[k] + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ + np.linspace(candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \ + for I in range(len(startend))]) + vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \ + for I in range(len(startend))]) + + score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( + 0.5 * oriImg.shape[0] / norm - 1, 0) + criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append( + [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) + + connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] and j not in connection[:, 4]): + connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) + if (len(connection) >= min(nA, nB)): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array([item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][-2:] += subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + return candidate, subset + +if __name__ == "__main__": + body_estimation = Body('../model/body_pose_model.pth') + + test_image = '../images/ski.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + candidate, subset = body_estimation(oriImg) + canvas = util.draw_bodypose(oriImg, candidate, subset) + plt.imshow(canvas[:, :, [2, 1, 0]]) + plt.show() diff --git a/annotator/openpose/hand.py b/annotator/openpose/hand.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0bf17165ad7eb225332b51f4a2aa16718664b2 --- /dev/null +++ b/annotator/openpose/hand.py @@ -0,0 +1,86 @@ +import cv2 +import json +import numpy as np +import math +import time +from scipy.ndimage.filters import gaussian_filter +import matplotlib.pyplot as plt +import matplotlib +import torch +from skimage.measure import label + +from .model import handpose_model +from . import util + +class Hand(object): + def __init__(self, model_path): + self.model = handpose_model() + if torch.cuda.is_available(): + self.model = self.model.cuda() + print('cuda') + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImg): + scale_search = [0.5, 1.0, 1.5, 2.0] + # scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre = 0.05 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22)) + # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + # data = data.permute([2, 0, 1]).unsqueeze(0).float() + with torch.no_grad(): + output = self.model(data).cpu().numpy() + # output = self.model(data).numpy()q + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps + heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC) + + heatmap_avg += heatmap / len(multiplier) + + all_peaks = [] + for part in range(21): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) + # 全部小于阈值 + if np.sum(binary) == 0: + all_peaks.append([0, 0]) + continue + label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) + max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 + label_img[label_img != max_index] = 0 + map_ori[label_img == 0] = 0 + + y, x = util.npmax(map_ori) + all_peaks.append([x, y]) + return np.array(all_peaks) + +if __name__ == "__main__": + hand_estimation = Hand('../model/hand_pose_model.pth') + + # test_image = '../images/hand.jpg' + test_image = '../images/hand.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + peaks = hand_estimation(oriImg) + canvas = util.draw_handpose(oriImg, peaks, True) + cv2.imshow('', canvas) + cv2.waitKey(0) \ No newline at end of file diff --git a/annotator/openpose/model.py b/annotator/openpose/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5dfc80de827a17beccb9b0f3f7588545be78c9de --- /dev/null +++ b/annotator/openpose/model.py @@ -0,0 +1,219 @@ +import torch +from collections import OrderedDict + +import torch +import torch.nn as nn + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], + padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], + kernel_size=v[2], stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + +class bodypose_model(nn.Module): + def __init__(self): + super(bodypose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] + blocks = {} + block0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1]) + ]) + + + # Stage 1 + block1_1 = OrderedDict([ + ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) + ]) + + block1_2 = OrderedDict([ + ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) + ]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 + +class handpose_model(nn.Module): + def __init__(self): + super(handpose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ + 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] + # stage 1 + block1_0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3', [512, 512, 3, 1, 1]), + ('conv4_4', [512, 512, 3, 1, 1]), + ('conv5_1', [512, 512, 3, 1, 1]), + ('conv5_2', [512, 512, 3, 1, 1]), + ('conv5_3_CPM', [512, 128, 3, 1, 1]) + ]) + + block1_1 = OrderedDict([ + ('conv6_1_CPM', [128, 512, 1, 1, 0]), + ('conv6_2_CPM', [512, 22, 1, 1, 0]) + ]) + + blocks = {} + blocks['block1_0'] = block1_0 + blocks['block1_1'] = block1_1 + + # stage 2-6 + for i in range(2, 7): + blocks['block%d' % i] = OrderedDict([ + ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), + ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_0 = blocks['block1_0'] + self.model1_1 = blocks['block1_1'] + self.model2 = blocks['block2'] + self.model3 = blocks['block3'] + self.model4 = blocks['block4'] + self.model5 = blocks['block5'] + self.model6 = blocks['block6'] + + def forward(self, x): + out1_0 = self.model1_0(x) + out1_1 = self.model1_1(out1_0) + concat_stage2 = torch.cat([out1_1, out1_0], 1) + out_stage2 = self.model2(concat_stage2) + concat_stage3 = torch.cat([out_stage2, out1_0], 1) + out_stage3 = self.model3(concat_stage3) + concat_stage4 = torch.cat([out_stage3, out1_0], 1) + out_stage4 = self.model4(concat_stage4) + concat_stage5 = torch.cat([out_stage4, out1_0], 1) + out_stage5 = self.model5(concat_stage5) + concat_stage6 = torch.cat([out_stage5, out1_0], 1) + out_stage6 = self.model6(concat_stage6) + return out_stage6 + + diff --git a/annotator/openpose/util.py b/annotator/openpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..6f91ae0e65abaf0cbd62d803f56498991141e61b --- /dev/null +++ b/annotator/openpose/util.py @@ -0,0 +1,164 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + +# transfer caffe model to pytorch which will match the layer name +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + +# draw the body keypoint and lims +def draw_bodypose(canvas, candidate, subset): + stickwidth = 4 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + cur_canvas = canvas.copy() + Y = candidate[index.astype(int), 0] + X = candidate[index.astype(int), 1] + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) + canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) + # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]]) + # plt.imshow(canvas[:, :, [2, 1, 0]]) + return canvas + + +# image drawed by opencv is not good. +def draw_handpose(canvas, all_hand_peaks, show_number=False): + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + for ie, e in enumerate(edges): + if np.sum(np.all(peaks[e], axis=1)==0)==0: + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie/float(len(edges)), 1.0, 1.0])*255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + if show_number: + cv2.putText(canvas, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA) + return canvas + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/annotator/uniformer/LICENSE b/annotator/uniformer/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c38dc639e6e238fbf59608f80b3a6ff1928ac429 --- /dev/null +++ b/annotator/uniformer/LICENSE @@ -0,0 +1,203 @@ +Copyright 2022 SenseTime X-Lab. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022 SenseTime X-Lab. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/annotator/uniformer/__init__.py b/annotator/uniformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3364d40997447a4ec15ca7a525a4d0e92ab211bd --- /dev/null +++ b/annotator/uniformer/__init__.py @@ -0,0 +1,27 @@ +# Uniformer +# From https://github.com/Sense-X/UniFormer +# # Apache-2.0 license + +import os + +from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot +from annotator.uniformer.mmseg.core.evaluation import get_palette +from annotator.util import annotator_ckpts_path + + +checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth" + + +class UniformerDetector: + def __init__(self): + modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path) + config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py") + self.model = init_segmentor(config_file, modelpath).cuda() + + def __call__(self, img): + result = inference_segmentor(self.model, img) + res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1) + return res_img diff --git a/annotator/uniformer/configs/_base_/datasets/ade20k.py b/annotator/uniformer/configs/_base_/datasets/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..efc8b4bb20c981f3db6df7eb52b3dc0744c94cc0 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/ade20k.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'ADE20KDataset' +data_root = 'data/ade/ADEChallengeData2016' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/chase_db1.py b/annotator/uniformer/configs/_base_/datasets/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..298594ea925f87f22b37094a2ec50e370aec96a0 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/chase_db1.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'ChaseDB1Dataset' +data_root = 'data/CHASE_DB1' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (960, 999) +crop_size = (128, 128) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/cityscapes.py b/annotator/uniformer/configs/_base_/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..f21867c63e1835f6fceb61f066e802fd8fd2a735 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/cityscapes.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'CityscapesDataset' +data_root = 'data/cityscapes/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/train', + ann_dir='gtFine/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py b/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py new file mode 100644 index 0000000000000000000000000000000000000000..336c7b254fe392b4703039fec86a83acdbd2e1a5 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (769, 769) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2049, 1025), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/drive.py b/annotator/uniformer/configs/_base_/datasets/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..06e8ff606e0d2a4514ec8b7d2c6c436a32efcbf4 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/drive.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'DRIVEDataset' +data_root = 'data/DRIVE' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (584, 565) +crop_size = (64, 64) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/hrf.py b/annotator/uniformer/configs/_base_/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..242d790eb1b83e75cf6b7eaa7a35c674099311ad --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/hrf.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'HRFDataset' +data_root = 'data/HRF' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (2336, 3504) +crop_size = (256, 256) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_context.py b/annotator/uniformer/configs/_base_/datasets/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..ff65bad1b86d7e3a5980bb5b9fc55798dc8df5f4 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/pascal_context.py @@ -0,0 +1,60 @@ +# dataset settings +dataset_type = 'PascalContextDataset' +data_root = 'data/VOCdevkit/VOC2010/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +img_scale = (520, 520) +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py b/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py new file mode 100644 index 0000000000000000000000000000000000000000..37585abab89834b95cd5bdd993b994fca1db65f6 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py @@ -0,0 +1,60 @@ +# dataset settings +dataset_type = 'PascalContextDataset59' +data_root = 'data/VOCdevkit/VOC2010/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +img_scale = (520, 520) +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py b/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1d42d0c5781f56dc177d860d856bb34adce555 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py @@ -0,0 +1,57 @@ +# dataset settings +dataset_type = 'PascalVOCDataset' +data_root = 'data/VOCdevkit/VOC2012' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py b/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..3f23b6717d53ad29f02dd15046802a2631a5076b --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py @@ -0,0 +1,9 @@ +_base_ = './pascal_voc12.py' +# dataset settings +data = dict( + train=dict( + ann_dir=['SegmentationClass', 'SegmentationClassAug'], + split=[ + 'ImageSets/Segmentation/train.txt', + 'ImageSets/Segmentation/aug.txt' + ])) diff --git a/annotator/uniformer/configs/_base_/datasets/stare.py b/annotator/uniformer/configs/_base_/datasets/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..3f71b25488cc11a6b4d582ac52b5a24e1ad1cf8e --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/stare.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'STAREDataset' +data_root = 'data/STARE' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (605, 700) +crop_size = (128, 128) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/default_runtime.py b/annotator/uniformer/configs/_base_/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..b564cc4e7e7d9a67dacaaddecb100e4d8f5c005b --- /dev/null +++ b/annotator/uniformer/configs/_base_/default_runtime.py @@ -0,0 +1,14 @@ +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True diff --git a/annotator/uniformer/configs/_base_/models/ann_r50-d8.py b/annotator/uniformer/configs/_base_/models/ann_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..a2cb653827e44e6015b3b83bc578003e614a6aa1 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/ann_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ANNHead', + in_channels=[1024, 2048], + in_index=[2, 3], + channels=512, + project_channels=256, + query_scales=(1, ), + key_pool_scales=(1, 3, 6, 8), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..c8f5316cbcf3896ba9de7ca2c801eba512f01d5e --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='APCHead', + in_channels=2048, + in_index=3, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..794148f576b9e215c3c6963e73dffe98204b7717 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='CCHead', + in_channels=2048, + in_index=3, + channels=512, + recurrence=2, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/cgnet.py b/annotator/uniformer/configs/_base_/models/cgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..eff8d9458c877c5db894957e0b1b4597e40da6ab --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/cgnet.py @@ -0,0 +1,35 @@ +# model settings +norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='CGNet', + norm_cfg=norm_cfg, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16)), + decode_head=dict( + type='FCNHead', + in_channels=256, + in_index=2, + channels=256, + num_convs=0, + concat_input=False, + dropout_ratio=0, + num_classes=19, + norm_cfg=norm_cfg, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + class_weight=[ + 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352, + 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905, + 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587, + 10.396974, 10.055647 + ])), + # model training and testing settings + train_cfg=dict(sampler=None), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/danet_r50-d8.py b/annotator/uniformer/configs/_base_/models/danet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..2c934939fac48525f22ad86f489a041dd7db7d09 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/danet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DAHead', + in_channels=2048, + in_index=3, + channels=512, + pam_channels=64, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py b/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a43bee01422ad4795dd27874e0cd4bb6cbfecf --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd262999d8b2cb8e14a5c32190ae73f479d8e81 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py @@ -0,0 +1,50 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='ASPPHead', + in_channels=64, + in_index=4, + channels=16, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py b/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..050e39e091d816df9028d23aa3ecf9db74e441e1 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DepthwiseSeparableASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + c1_in_channels=256, + c1_channels=48, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..d22ba52640bebd805b3b8d07025e276dfb023759 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DMHead', + in_channels=2048, + in_index=3, + channels=512, + filter_sizes=(1, 3, 5, 7), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py b/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..edb4c174c51e34c103737ba39bfc48bf831e561d --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DNLHead', + in_channels=2048, + in_index=3, + channels=512, + dropout_ratio=0.1, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py b/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..26adcd430926de0862204a71d345f2543167f27b --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py @@ -0,0 +1,47 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='EMAHead', + in_channels=2048, + in_index=3, + channels=256, + ema_channels=512, + num_bases=64, + num_stages=3, + momentum=0.1, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..be777123a886503172a95fe0719e956a147bbd68 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py @@ -0,0 +1,48 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='EncHead', + in_channels=[512, 1024, 2048], + in_index=(1, 2, 3), + channels=512, + num_codes=32, + use_se_loss=True, + add_lateral=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_se_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/fast_scnn.py b/annotator/uniformer/configs/_base_/models/fast_scnn.py new file mode 100644 index 0000000000000000000000000000000000000000..32fdeb659355a5ce5ef2cc7c2f30742703811cdf --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/fast_scnn.py @@ -0,0 +1,57 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='FastSCNN', + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + norm_cfg=norm_cfg, + align_corners=False), + decode_head=dict( + type='DepthwiseSeparableFCNHead', + in_channels=128, + channels=128, + concat_input=False, + num_classes=19, + in_index=-1, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=128, + channels=32, + num_convs=1, + num_classes=19, + in_index=-2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + dict( + type='FCNHead', + in_channels=64, + channels=32, + num_convs=1, + num_classes=19, + in_index=-3, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/fcn_hr18.py b/annotator/uniformer/configs/_base_/models/fcn_hr18.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e299bc89ada56ca14bbffcbdb08a586b8ed9e9 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/fcn_hr18.py @@ -0,0 +1,52 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://msra/hrnetv2_w18', + backbone=dict( + type='HRNet', + norm_cfg=norm_cfg, + norm_eval=False, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144)))), + decode_head=dict( + type='FCNHead', + in_channels=[18, 36, 72, 144], + in_index=(0, 1, 2, 3), + channels=sum([18, 36, 72, 144]), + input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py b/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..5e98f6cc918b6146fc6d613c6918e825ef1355c3 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py @@ -0,0 +1,45 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='FCNHead', + in_channels=2048, + in_index=3, + channels=512, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py new file mode 100644 index 0000000000000000000000000000000000000000..a33e7972877f902d0e7d18401ca675e3e4e60a18 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py @@ -0,0 +1,51 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='FCNHead', + in_channels=64, + in_index=4, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/annotator/uniformer/configs/_base_/models/fpn_r50.py b/annotator/uniformer/configs/_base_/models/fpn_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..86ab327db92e44c14822d65f1c9277cb007f17c1 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/fpn_r50.py @@ -0,0 +1,36 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/fpn_uniformer.py b/annotator/uniformer/configs/_base_/models/fpn_uniformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8aae98c5991055bfcc08e82ccdc09f8b1d9f8a8d --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/fpn_uniformer.py @@ -0,0 +1,35 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1), + neck=dict( + type='FPN', + in_channels=[64, 128, 320, 512], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole') +) diff --git a/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..3d2ad69f5c22adfe79d5fdabf920217628987166 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='GCHead', + in_channels=2048, + in_index=3, + channels=512, + ratio=1 / 4., + pooling_type='att', + fusion_types=('channel_add', ), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py b/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..93258242a90695cc94a7c6bd41562d6a75988771 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py @@ -0,0 +1,25 @@ +# model settings +norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='MobileNetV3', + arch='large', + out_indices=(1, 3, 16), + norm_cfg=norm_cfg), + decode_head=dict( + type='LRASPPHead', + in_channels=(16, 24, 960), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py b/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..5674a39854cafd1f2e363bac99c58ccae62f24da --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='NLHead', + in_channels=2048, + in_index=3, + channels=512, + dropout_ratio=0.1, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py b/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py new file mode 100644 index 0000000000000000000000000000000000000000..c60f62a7cdf3f5c5096a7a7e725e8268fddcb057 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py @@ -0,0 +1,68 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://msra/hrnetv2_w18', + backbone=dict( + type='HRNet', + norm_cfg=norm_cfg, + norm_eval=False, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144)))), + decode_head=[ + dict( + type='FCNHead', + in_channels=[18, 36, 72, 144], + channels=sum([18, 36, 72, 144]), + in_index=(0, 1, 2, 3), + input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='OCRHead', + in_channels=[18, 36, 72, 144], + in_index=(0, 1, 2, 3), + input_transform='resize_concat', + channels=512, + ocr_channels=256, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..615aa3ff703942b6c22b2d6e9642504dd3e41ebd --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py @@ -0,0 +1,47 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=[ + dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='OCRHead', + in_channels=2048, + in_index=3, + channels=512, + ocr_channels=256, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/pointrend_r50.py b/annotator/uniformer/configs/_base_/models/pointrend_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..9d323dbf9466d41e0800aa57ef84045f3d874bdf --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/pointrend_r50.py @@ -0,0 +1,56 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=[ + dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='PointHead', + in_channels=[256], + in_index=[0], + channels=256, + num_fcs=3, + coarse_pred_each_layer=True, + dropout_ratio=-1, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + # model training and testing settings + train_cfg=dict( + num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75), + test_cfg=dict( + mode='whole', + subdivision_steps=2, + subdivision_num_points=8196, + scale_factor=2)) diff --git a/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py b/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..689513fa9d2a40f14bf0ae4ae61f38f0dcc1b3da --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py @@ -0,0 +1,49 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='PSAHead', + in_channels=2048, + in_index=3, + channels=512, + mask_size=(97, 97), + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..f451e08ad2eb0732dcb806b1851eb978d4acf136 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='PSPHead', + in_channels=2048, + in_index=3, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py new file mode 100644 index 0000000000000000000000000000000000000000..fcff9ec4f41fad158344ecd77313dc14564f3682 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py @@ -0,0 +1,50 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='PSPHead', + in_channels=64, + in_index=4, + channels=16, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/annotator/uniformer/configs/_base_/models/upernet_r50.py b/annotator/uniformer/configs/_base_/models/upernet_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..10974962fdd7136031fd06de1700f497d355ceaa --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/upernet_r50.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='UPerHead', + in_channels=[256, 512, 1024, 2048], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/annotator/uniformer/configs/_base_/models/upernet_uniformer.py b/annotator/uniformer/configs/_base_/models/upernet_uniformer.py new file mode 100644 index 0000000000000000000000000000000000000000..41aa4db809dc6e2c508e98051f61807d07477903 --- /dev/null +++ b/annotator/uniformer/configs/_base_/models/upernet_uniformer.py @@ -0,0 +1,43 @@ +# model settings +norm_cfg = dict(type='BN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1), + decode_head=dict( + type='UPerHead', + in_channels=[64, 128, 320, 512], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=320, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_160k.py b/annotator/uniformer/configs/_base_/schedules/schedule_160k.py new file mode 100644 index 0000000000000000000000000000000000000000..52603890b10f25faf8eec9f9e5a4468fae09b811 --- /dev/null +++ b/annotator/uniformer/configs/_base_/schedules/schedule_160k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=160000) +checkpoint_config = dict(by_epoch=False, interval=16000) +evaluation = dict(interval=16000, metric='mIoU') diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_20k.py b/annotator/uniformer/configs/_base_/schedules/schedule_20k.py new file mode 100644 index 0000000000000000000000000000000000000000..bf780a1b6f6521833c6a5859675147824efa599d --- /dev/null +++ b/annotator/uniformer/configs/_base_/schedules/schedule_20k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=20000) +checkpoint_config = dict(by_epoch=False, interval=2000) +evaluation = dict(interval=2000, metric='mIoU') diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_40k.py b/annotator/uniformer/configs/_base_/schedules/schedule_40k.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbf841abcb26eed87bf76ab816aff4bae0630ee --- /dev/null +++ b/annotator/uniformer/configs/_base_/schedules/schedule_40k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=40000) +checkpoint_config = dict(by_epoch=False, interval=4000) +evaluation = dict(interval=4000, metric='mIoU') diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_80k.py b/annotator/uniformer/configs/_base_/schedules/schedule_80k.py new file mode 100644 index 0000000000000000000000000000000000000000..c190cee6bdc7922b688ea75dc8f152fa15c24617 --- /dev/null +++ b/annotator/uniformer/configs/_base_/schedules/schedule_80k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=80000) +checkpoint_config = dict(by_epoch=False, interval=8000) +evaluation = dict(interval=8000, metric='mIoU') diff --git a/annotator/uniformer/exp/upernet_global_small/config.py b/annotator/uniformer/exp/upernet_global_small/config.py new file mode 100644 index 0000000000000000000000000000000000000000..01db96bf9b0be531aa0eaf62fee51543712f8670 --- /dev/null +++ b/annotator/uniformer/exp/upernet_global_small/config.py @@ -0,0 +1,38 @@ +_base_ = [ + '../../configs/_base_/models/upernet_uniformer.py', + '../../configs/_base_/datasets/ade20k.py', + '../../configs/_base_/default_runtime.py', + '../../configs/_base_/schedules/schedule_160k.py' +] +model = dict( + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + drop_path_rate=0.25, + windows=False, + hybrid=False + ), + decode_head=dict( + in_channels=[64, 128, 320, 512], + num_classes=150 + ), + auxiliary_head=dict( + in_channels=320, + num_classes=150 + )) + +# AdamW optimizer, no weight decay for position embedding & layer norm in backbone +optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, + paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.)})) + +lr_config = dict(_delete_=True, policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) + +data=dict(samples_per_gpu=2) \ No newline at end of file diff --git a/annotator/uniformer/exp/upernet_global_small/run.sh b/annotator/uniformer/exp/upernet_global_small/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..9fb22edfa7a32624ea08a63fe7d720c40db3b696 --- /dev/null +++ b/annotator/uniformer/exp/upernet_global_small/run.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +work_path=$(dirname $0) +PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=8 \ + tools/train.py ${work_path}/config.py \ + --launcher pytorch \ + --options model.backbone.pretrained_path='your_model_path/uniformer_small_in1k.pth' \ + --work-dir ${work_path}/ckpt \ + 2>&1 | tee -a ${work_path}/log.txt diff --git a/annotator/uniformer/exp/upernet_global_small/test.sh b/annotator/uniformer/exp/upernet_global_small/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..d9a85e7a0d3b7c96b060f473d41254b37a382fcb --- /dev/null +++ b/annotator/uniformer/exp/upernet_global_small/test.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +work_path=$(dirname $0) +PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=8 \ + tools/test.py ${work_path}/test_config_h32.py \ + ${work_path}/ckpt/latest.pth \ + --launcher pytorch \ + --eval mIoU \ + 2>&1 | tee -a ${work_path}/log.txt diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_g.py b/annotator/uniformer/exp/upernet_global_small/test_config_g.py new file mode 100644 index 0000000000000000000000000000000000000000..e43737a98a3b174a9f2fe059c06d511144686459 --- /dev/null +++ b/annotator/uniformer/exp/upernet_global_small/test_config_g.py @@ -0,0 +1,38 @@ +_base_ = [ + '../../configs/_base_/models/upernet_uniformer.py', + '../../configs/_base_/datasets/ade20k.py', + '../../configs/_base_/default_runtime.py', + '../../configs/_base_/schedules/schedule_160k.py' +] +model = dict( + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + drop_path_rate=0.25, + windows=False, + hybrid=False, + ), + decode_head=dict( + in_channels=[64, 128, 320, 512], + num_classes=150 + ), + auxiliary_head=dict( + in_channels=320, + num_classes=150 + )) + +# AdamW optimizer, no weight decay for position embedding & layer norm in backbone +optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, + paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.)})) + +lr_config = dict(_delete_=True, policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) + +data=dict(samples_per_gpu=2) \ No newline at end of file diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_h32.py b/annotator/uniformer/exp/upernet_global_small/test_config_h32.py new file mode 100644 index 0000000000000000000000000000000000000000..a31e3874f76f9f7b089ac8834d85df2441af9b0e --- /dev/null +++ b/annotator/uniformer/exp/upernet_global_small/test_config_h32.py @@ -0,0 +1,39 @@ +_base_ = [ + '../../configs/_base_/models/upernet_uniformer.py', + '../../configs/_base_/datasets/ade20k.py', + '../../configs/_base_/default_runtime.py', + '../../configs/_base_/schedules/schedule_160k.py' +] +model = dict( + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + drop_path_rate=0.25, + windows=False, + hybrid=True, + window_size=32 + ), + decode_head=dict( + in_channels=[64, 128, 320, 512], + num_classes=150 + ), + auxiliary_head=dict( + in_channels=320, + num_classes=150 + )) + +# AdamW optimizer, no weight decay for position embedding & layer norm in backbone +optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, + paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.)})) + +lr_config = dict(_delete_=True, policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) + +data=dict(samples_per_gpu=2) \ No newline at end of file diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_w32.py b/annotator/uniformer/exp/upernet_global_small/test_config_w32.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9e06f029e46c14cb9ddb39319cabe86fef9b44 --- /dev/null +++ b/annotator/uniformer/exp/upernet_global_small/test_config_w32.py @@ -0,0 +1,39 @@ +_base_ = [ + '../../configs/_base_/models/upernet_uniformer.py', + '../../configs/_base_/datasets/ade20k.py', + '../../configs/_base_/default_runtime.py', + '../../configs/_base_/schedules/schedule_160k.py' +] +model = dict( + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + drop_path_rate=0.25, + windows=True, + hybrid=False, + window_size=32 + ), + decode_head=dict( + in_channels=[64, 128, 320, 512], + num_classes=150 + ), + auxiliary_head=dict( + in_channels=320, + num_classes=150 + )) + +# AdamW optimizer, no weight decay for position embedding & layer norm in backbone +optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, + paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.)})) + +lr_config = dict(_delete_=True, policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) + +data=dict(samples_per_gpu=2) \ No newline at end of file diff --git a/annotator/uniformer/mmcv/__init__.py b/annotator/uniformer/mmcv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..210a2989138380559f23045b568d0fbbeb918c03 --- /dev/null +++ b/annotator/uniformer/mmcv/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# flake8: noqa +from .arraymisc import * +from .fileio import * +from .image import * +from .utils import * +from .version import * +from .video import * +from .visualization import * + +# The following modules are not imported to this level, so mmcv may be used +# without PyTorch. +# - runner +# - parallel +# - op diff --git a/annotator/uniformer/mmcv/arraymisc/__init__.py b/annotator/uniformer/mmcv/arraymisc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4700d6139ae3d604ff6e542468cce4200c020c --- /dev/null +++ b/annotator/uniformer/mmcv/arraymisc/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .quantization import dequantize, quantize + +__all__ = ['quantize', 'dequantize'] diff --git a/annotator/uniformer/mmcv/arraymisc/quantization.py b/annotator/uniformer/mmcv/arraymisc/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..8e47a3545780cf071a1ef8195efb0b7b662c8186 --- /dev/null +++ b/annotator/uniformer/mmcv/arraymisc/quantization.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum( + np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - + min_val) / levels + min_val + + return dequantized_arr diff --git a/annotator/uniformer/mmcv/cnn/__init__.py b/annotator/uniformer/mmcv/cnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7246c897430f0cc7ce12719ad8608824fc734446 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .alexnet import AlexNet +# yapf: disable +from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, + PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS, + ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule, + ConvTranspose2d, ConvTranspose3d, ConvWS2d, + DepthwiseSeparableConvModule, GeneralizedAttention, + HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d, + NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish, + build_activation_layer, build_conv_layer, + build_norm_layer, build_padding_layer, build_plugin_layer, + build_upsample_layer, conv_ws_2d, is_norm) +from .builder import MODELS, build_model_from_cfg +# yapf: enable +from .resnet import ResNet, make_res_layer +from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit, + NormalInit, PretrainedInit, TruncNormalInit, UniformInit, + XavierInit, bias_init_with_prob, caffe2_xavier_init, + constant_init, fuse_conv_bn, get_model_complexity_info, + initialize, kaiming_init, normal_init, trunc_normal_init, + uniform_init, xavier_init) +from .vgg import VGG, make_vgg_layer + +__all__ = [ + 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer', + 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init', + 'uniform_init', 'kaiming_init', 'caffe2_xavier_init', + 'bias_init_with_prob', 'ConvModule', 'build_activation_layer', + 'build_conv_layer', 'build_norm_layer', 'build_padding_layer', + 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d', + 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish', + 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', + 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', + 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d', + 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', + 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', + 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit', + 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', + 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg' +] diff --git a/annotator/uniformer/mmcv/cnn/alexnet.py b/annotator/uniformer/mmcv/cnn/alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..89e36b8c7851f895d9ae7f07149f0e707456aab0 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/alexnet.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.nn as nn + + +class AlexNet(nn.Module): + """AlexNet backbone. + + Args: + num_classes (int): number of classes for classification. + """ + + def __init__(self, num_classes=-1): + super(AlexNet, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + from ..runner import load_checkpoint + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + # use default initializer + pass + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + + return x diff --git a/annotator/uniformer/mmcv/cnn/bricks/__init__.py b/annotator/uniformer/mmcv/cnn/bricks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f33124ed23fc6f27119a37bcb5ab004d3572be0 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .activation import build_activation_layer +from .context_block import ContextBlock +from .conv import build_conv_layer +from .conv2d_adaptive_padding import Conv2dAdaptivePadding +from .conv_module import ConvModule +from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d +from .depthwise_separable_conv_module import DepthwiseSeparableConvModule +from .drop import Dropout, DropPath +from .generalized_attention import GeneralizedAttention +from .hsigmoid import HSigmoid +from .hswish import HSwish +from .non_local import NonLocal1d, NonLocal2d, NonLocal3d +from .norm import build_norm_layer, is_norm +from .padding import build_padding_layer +from .plugin import build_plugin_layer +from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, + PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS) +from .scale import Scale +from .swish import Swish +from .upsample import build_upsample_layer +from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, + Linear, MaxPool2d, MaxPool3d) + +__all__ = [ + 'ConvModule', 'build_activation_layer', 'build_conv_layer', + 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer', + 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d', + 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention', + 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', + 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d', + 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear', + 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', + 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath' +] diff --git a/annotator/uniformer/mmcv/cnn/bricks/activation.py b/annotator/uniformer/mmcv/cnn/bricks/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..cab2712287d5ef7be2f079dcb54a94b96394eab5 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/activation.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version +from .registry import ACTIVATION_LAYERS + +for module in [ + nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU, + nn.Sigmoid, nn.Tanh +]: + ACTIVATION_LAYERS.register_module(module=module) + + +@ACTIVATION_LAYERS.register_module(name='Clip') +@ACTIVATION_LAYERS.register_module() +class Clamp(nn.Module): + """Clamp activation layer. + + This activation function is to clamp the feature map value within + :math:`[min, max]`. More details can be found in ``torch.clamp()``. + + Args: + min (Number | optional): Lower-bound of the range to be clamped to. + Default to -1. + max (Number | optional): Upper-bound of the range to be clamped to. + Default to 1. + """ + + def __init__(self, min=-1., max=1.): + super(Clamp, self).__init__() + self.min = min + self.max = max + + def forward(self, x): + """Forward function. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: Clamped tensor. + """ + return torch.clamp(x, min=self.min, max=self.max) + + +class GELU(nn.Module): + r"""Applies the Gaussian Error Linear Units function: + + .. math:: + \text{GELU}(x) = x * \Phi(x) + where :math:`\Phi(x)` is the Cumulative Distribution Function for + Gaussian Distribution. + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: scripts/activation_images/GELU.png + + Examples:: + + >>> m = nn.GELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input): + return F.gelu(input) + + +if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.4')): + ACTIVATION_LAYERS.register_module(module=GELU) +else: + ACTIVATION_LAYERS.register_module(module=nn.GELU) + + +def build_activation_layer(cfg): + """Build activation layer. + + Args: + cfg (dict): The activation layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate an activation layer. + + Returns: + nn.Module: Created activation layer. + """ + return build_from_cfg(cfg, ACTIVATION_LAYERS) diff --git a/annotator/uniformer/mmcv/cnn/bricks/context_block.py b/annotator/uniformer/mmcv/cnn/bricks/context_block.py new file mode 100644 index 0000000000000000000000000000000000000000..d60fdb904c749ce3b251510dff3cc63cea70d42e --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/context_block.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from ..utils import constant_init, kaiming_init +from .registry import PLUGIN_LAYERS + + +def last_zero_init(m): + if isinstance(m, nn.Sequential): + constant_init(m[-1], val=0) + else: + constant_init(m, val=0) + + +@PLUGIN_LAYERS.register_module() +class ContextBlock(nn.Module): + """ContextBlock module in GCNet. + + See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond' + (https://arxiv.org/abs/1904.11492) for details. + + Args: + in_channels (int): Channels of the input feature map. + ratio (float): Ratio of channels of transform bottleneck + pooling_type (str): Pooling method for context modeling. + Options are 'att' and 'avg', stand for attention pooling and + average pooling respectively. Default: 'att'. + fusion_types (Sequence[str]): Fusion method for feature fusion, + Options are 'channels_add', 'channel_mul', stand for channelwise + addition and multiplication respectively. Default: ('channel_add',) + """ + + _abbr_ = 'context_block' + + def __init__(self, + in_channels, + ratio, + pooling_type='att', + fusion_types=('channel_add', )): + super(ContextBlock, self).__init__() + assert pooling_type in ['avg', 'att'] + assert isinstance(fusion_types, (list, tuple)) + valid_fusion_types = ['channel_add', 'channel_mul'] + assert all([f in valid_fusion_types for f in fusion_types]) + assert len(fusion_types) > 0, 'at least one fusion should be used' + self.in_channels = in_channels + self.ratio = ratio + self.planes = int(in_channels * ratio) + self.pooling_type = pooling_type + self.fusion_types = fusion_types + if pooling_type == 'att': + self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1) + self.softmax = nn.Softmax(dim=2) + else: + self.avg_pool = nn.AdaptiveAvgPool2d(1) + if 'channel_add' in fusion_types: + self.channel_add_conv = nn.Sequential( + nn.Conv2d(self.in_channels, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(inplace=True), # yapf: disable + nn.Conv2d(self.planes, self.in_channels, kernel_size=1)) + else: + self.channel_add_conv = None + if 'channel_mul' in fusion_types: + self.channel_mul_conv = nn.Sequential( + nn.Conv2d(self.in_channels, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(inplace=True), # yapf: disable + nn.Conv2d(self.planes, self.in_channels, kernel_size=1)) + else: + self.channel_mul_conv = None + self.reset_parameters() + + def reset_parameters(self): + if self.pooling_type == 'att': + kaiming_init(self.conv_mask, mode='fan_in') + self.conv_mask.inited = True + + if self.channel_add_conv is not None: + last_zero_init(self.channel_add_conv) + if self.channel_mul_conv is not None: + last_zero_init(self.channel_mul_conv) + + def spatial_pool(self, x): + batch, channel, height, width = x.size() + if self.pooling_type == 'att': + input_x = x + # [N, C, H * W] + input_x = input_x.view(batch, channel, height * width) + # [N, 1, C, H * W] + input_x = input_x.unsqueeze(1) + # [N, 1, H, W] + context_mask = self.conv_mask(x) + # [N, 1, H * W] + context_mask = context_mask.view(batch, 1, height * width) + # [N, 1, H * W] + context_mask = self.softmax(context_mask) + # [N, 1, H * W, 1] + context_mask = context_mask.unsqueeze(-1) + # [N, 1, C, 1] + context = torch.matmul(input_x, context_mask) + # [N, C, 1, 1] + context = context.view(batch, channel, 1, 1) + else: + # [N, C, 1, 1] + context = self.avg_pool(x) + + return context + + def forward(self, x): + # [N, C, 1, 1] + context = self.spatial_pool(x) + + out = x + if self.channel_mul_conv is not None: + # [N, C, 1, 1] + channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) + out = out * channel_mul_term + if self.channel_add_conv is not None: + # [N, C, 1, 1] + channel_add_term = self.channel_add_conv(context) + out = out + channel_add_term + + return out diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv.py b/annotator/uniformer/mmcv/cnn/bricks/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..cf54491997a48ac3e7fadc4183ab7bf3e831024c --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/conv.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn + +from .registry import CONV_LAYERS + +CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d) +CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d) +CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d) +CONV_LAYERS.register_module('Conv', module=nn.Conv2d) + + +def build_conv_layer(cfg, *args, **kwargs): + """Build convolution layer. + + Args: + cfg (None or dict): The conv layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate an conv layer. + args (argument list): Arguments passed to the `__init__` + method of the corresponding conv layer. + kwargs (keyword arguments): Keyword arguments passed to the `__init__` + method of the corresponding conv layer. + + Returns: + nn.Module: Created conv layer. + """ + if cfg is None: + cfg_ = dict(type='Conv2d') + else: + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in CONV_LAYERS: + raise KeyError(f'Unrecognized norm type {layer_type}') + else: + conv_layer = CONV_LAYERS.get(layer_type) + + layer = conv_layer(*args, **kwargs, **cfg_) + + return layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py b/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py new file mode 100644 index 0000000000000000000000000000000000000000..b45e758ac6cf8dfb0382d072fe09125bc7e9b888 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from torch import nn +from torch.nn import functional as F + +from .registry import CONV_LAYERS + + +@CONV_LAYERS.register_module() +class Conv2dAdaptivePadding(nn.Conv2d): + """Implementation of 2D convolution in tensorflow with `padding` as "same", + which applies padding to input (if needed) so that input image gets fully + covered by filter and stride you specified. For stride 1, this will ensure + that output image size is same as input. For stride of 2, output dimensions + will be half, for example. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, + dilation, groups, bias) + + def forward(self, x): + img_h, img_w = x.size()[-2:] + kernel_h, kernel_w = self.weight.size()[-2:] + stride_h, stride_w = self.stride + output_h = math.ceil(img_h / stride_h) + output_w = math.ceil(img_w / stride_w) + pad_h = ( + max((output_h - 1) * self.stride[0] + + (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0)) + pad_w = ( + max((output_w - 1) * self.stride[1] + + (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0)) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv_module.py b/annotator/uniformer/mmcv/cnn/bricks/conv_module.py new file mode 100644 index 0000000000000000000000000000000000000000..e60e7e62245071c77b652093fddebff3948d7c3e --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/conv_module.py @@ -0,0 +1,206 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn + +from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm +from ..utils import constant_init, kaiming_init +from .activation import build_activation_layer +from .conv import build_conv_layer +from .norm import build_norm_layer +from .padding import build_padding_layer +from .registry import PLUGIN_LAYERS + + +@PLUGIN_LAYERS.register_module() +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = 'conv_block' + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias='auto', + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + inplace=True, + with_spectral_norm=False, + padding_mode='zeros', + order=('conv', 'norm', 'act')): + super(ConvModule, self).__init__() + assert conv_cfg is None or isinstance(conv_cfg, dict) + assert norm_cfg is None or isinstance(norm_cfg, dict) + assert act_cfg is None or isinstance(act_cfg, dict) + official_padding_mode = ['zeros', 'circular'] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == set(['conv', 'norm', 'act']) + + self.with_norm = norm_cfg is not None + self.with_activation = act_cfg is not None + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == 'auto': + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + pad_cfg = dict(type=padding_mode) + self.padding_layer = build_padding_layer(pad_cfg, padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index('norm') > order.index('conv'): + norm_channels = out_channels + else: + norm_channels = in_channels + self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) + self.add_module(self.norm_name, norm) + if self.with_bias: + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn( + 'Unnecessary conv bias before batch/instance norm') + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + act_cfg_ = act_cfg.copy() + # nn.Tanh has no 'inplace' argument + if act_cfg_['type'] not in [ + 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish' + ]: + act_cfg_.setdefault('inplace', inplace) + self.activate = build_activation_layer(act_cfg_) + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, 'init_weights'): + if self.with_activation and self.act_cfg['type'] == 'LeakyReLU': + nonlinearity = 'leaky_relu' + a = self.act_cfg.get('negative_slope', 0.01) + else: + nonlinearity = 'relu' + a = 0 + kaiming_init(self.conv, a=a, nonlinearity=nonlinearity) + if self.with_norm: + constant_init(self.norm, 1, bias=0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == 'conv': + if self.with_explicit_padding: + x = self.padding_layer(x) + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + x = self.norm(x) + elif layer == 'act' and activate and self.with_activation: + x = self.activate(x) + return x diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py b/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py new file mode 100644 index 0000000000000000000000000000000000000000..a3941e27874993418b3b5708d5a7485f175ff9c8 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .registry import CONV_LAYERS + + +def conv_ws_2d(input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + eps=1e-5): + c_in = weight.size(0) + weight_flat = weight.view(c_in, -1) + mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1) + std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1) + weight = (weight - mean) / (std + eps) + return F.conv2d(input, weight, bias, stride, padding, dilation, groups) + + +@CONV_LAYERS.register_module('ConvWS') +class ConvWS2d(nn.Conv2d): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + eps=1e-5): + super(ConvWS2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.eps = eps + + def forward(self, x): + return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.eps) + + +@CONV_LAYERS.register_module(name='ConvAWS') +class ConvAWS2d(nn.Conv2d): + """AWS (Adaptive Weight Standardization) + + This is a variant of Weight Standardization + (https://arxiv.org/pdf/1903.10520.pdf) + It is used in DetectoRS to avoid NaN + (https://arxiv.org/pdf/2006.02334.pdf) + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the conv kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If set True, adds a learnable bias to the + output. Default: True + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.register_buffer('weight_gamma', + torch.ones(self.out_channels, 1, 1, 1)) + self.register_buffer('weight_beta', + torch.zeros(self.out_channels, 1, 1, 1)) + + def _get_weight(self, weight): + weight_flat = weight.view(weight.size(0), -1) + mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1) + std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1) + weight = (weight - mean) / std + weight = self.weight_gamma * weight + self.weight_beta + return weight + + def forward(self, x): + weight = self._get_weight(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Override default load function. + + AWS overrides the function _load_from_state_dict to recover + weight_gamma and weight_beta if they are missing. If weight_gamma and + weight_beta are found in the checkpoint, this function will return + after super()._load_from_state_dict. Otherwise, it will compute the + mean and std of the pretrained weights and store them in weight_beta + and weight_gamma. + """ + + self.weight_gamma.data.fill_(-1) + local_missing_keys = [] + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, local_missing_keys, + unexpected_keys, error_msgs) + if self.weight_gamma.data.mean() > 0: + for k in local_missing_keys: + missing_keys.append(k) + return + weight = self.weight.data + weight_flat = weight.view(weight.size(0), -1) + mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1) + std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1) + self.weight_beta.data.copy_(mean) + self.weight_gamma.data.copy_(std) + missing_gamma_beta = [ + k for k in local_missing_keys + if k.endswith('weight_gamma') or k.endswith('weight_beta') + ] + for k in missing_gamma_beta: + local_missing_keys.remove(k) + for k in local_missing_keys: + missing_keys.append(k) diff --git a/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py b/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py new file mode 100644 index 0000000000000000000000000000000000000000..722d5d8d71f75486e2db3008907c4eadfca41d63 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from .conv_module import ConvModule + + +class DepthwiseSeparableConvModule(nn.Module): + """Depthwise separable convolution module. + + See https://arxiv.org/pdf/1704.04861.pdf for details. + + This module can replace a ConvModule with the conv block replaced by two + conv block: depthwise conv block and pointwise conv block. The depthwise + conv block contains depthwise-conv/norm/activation layers. The pointwise + conv block contains pointwise-conv/norm/activation layers. It should be + noted that there will be norm/activation layer in the depthwise conv block + if `norm_cfg` and `act_cfg` are specified. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. Default: 1. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. Default: 0. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. Default: 1. + norm_cfg (dict): Default norm config for both depthwise ConvModule and + pointwise ConvModule. Default: None. + act_cfg (dict): Default activation config for both depthwise ConvModule + and pointwise ConvModule. Default: dict(type='ReLU'). + dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is + 'default', it will be the same as `norm_cfg`. Default: 'default'. + dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: 'default'. + pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is + 'default', it will be the same as `norm_cfg`. Default: 'default'. + pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: 'default'. + kwargs (optional): Other shared arguments for depthwise and pointwise + ConvModule. See ConvModule for ref. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + dw_norm_cfg='default', + dw_act_cfg='default', + pw_norm_cfg='default', + pw_act_cfg='default', + **kwargs): + super(DepthwiseSeparableConvModule, self).__init__() + assert 'groups' not in kwargs, 'groups should not be specified' + + # if norm/activation config of depthwise/pointwise ConvModule is not + # specified, use default config. + dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg + dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg + pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg + pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg + + # depthwise convolution + self.depthwise_conv = ConvModule( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + norm_cfg=dw_norm_cfg, + act_cfg=dw_act_cfg, + **kwargs) + + self.pointwise_conv = ConvModule( + in_channels, + out_channels, + 1, + norm_cfg=pw_norm_cfg, + act_cfg=pw_act_cfg, + **kwargs) + + def forward(self, x): + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + return x diff --git a/annotator/uniformer/mmcv/cnn/bricks/drop.py b/annotator/uniformer/mmcv/cnn/bricks/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b4fccd457a0d51fb10c789df3c8537fe7b67c1 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/drop.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from annotator.uniformer.mmcv import build_from_cfg +from .registry import DROPOUT_LAYERS + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + + We follow the implementation + https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + # handle tensors with different dimensions, not just 4D tensors. + shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + output = x.div(keep_prob) * random_tensor.floor() + return output + + +@DROPOUT_LAYERS.register_module() +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + + We follow the implementation + https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 + + Args: + drop_prob (float): Probability of the path to be zeroed. Default: 0.1 + """ + + def __init__(self, drop_prob=0.1): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +@DROPOUT_LAYERS.register_module() +class Dropout(nn.Dropout): + """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of + ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with + ``DropPath`` + + Args: + drop_prob (float): Probability of the elements to be + zeroed. Default: 0.5. + inplace (bool): Do the operation inplace or not. Default: False. + """ + + def __init__(self, drop_prob=0.5, inplace=False): + super().__init__(p=drop_prob, inplace=inplace) + + +def build_dropout(cfg, default_args=None): + """Builder for drop out layers.""" + return build_from_cfg(cfg, DROPOUT_LAYERS, default_args) diff --git a/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py b/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..988d9adf2f289ef223bd1c680a5ae1d3387f0269 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py @@ -0,0 +1,412 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import kaiming_init +from .registry import PLUGIN_LAYERS + + +@PLUGIN_LAYERS.register_module() +class GeneralizedAttention(nn.Module): + """GeneralizedAttention module. + + See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks' + (https://arxiv.org/abs/1711.07971) for details. + + Args: + in_channels (int): Channels of the input feature map. + spatial_range (int): The spatial range. -1 indicates no spatial range + constraint. Default: -1. + num_heads (int): The head number of empirical_attention module. + Default: 9. + position_embedding_dim (int): The position embedding dimension. + Default: -1. + position_magnitude (int): A multiplier acting on coord difference. + Default: 1. + kv_stride (int): The feature stride acting on key/value feature map. + Default: 2. + q_stride (int): The feature stride acting on query feature map. + Default: 1. + attention_type (str): A binary indicator string for indicating which + items in generalized empirical_attention module are used. + Default: '1111'. + + - '1000' indicates 'query and key content' (appr - appr) item, + - '0100' indicates 'query content and relative position' + (appr - position) item, + - '0010' indicates 'key content only' (bias - appr) item, + - '0001' indicates 'relative position only' (bias - position) item. + """ + + _abbr_ = 'gen_attention_block' + + def __init__(self, + in_channels, + spatial_range=-1, + num_heads=9, + position_embedding_dim=-1, + position_magnitude=1, + kv_stride=2, + q_stride=1, + attention_type='1111'): + + super(GeneralizedAttention, self).__init__() + + # hard range means local range for non-local operation + self.position_embedding_dim = ( + position_embedding_dim + if position_embedding_dim > 0 else in_channels) + + self.position_magnitude = position_magnitude + self.num_heads = num_heads + self.in_channels = in_channels + self.spatial_range = spatial_range + self.kv_stride = kv_stride + self.q_stride = q_stride + self.attention_type = [bool(int(_)) for _ in attention_type] + self.qk_embed_dim = in_channels // num_heads + out_c = self.qk_embed_dim * num_heads + + if self.attention_type[0] or self.attention_type[1]: + self.query_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_c, + kernel_size=1, + bias=False) + self.query_conv.kaiming_init = True + + if self.attention_type[0] or self.attention_type[2]: + self.key_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_c, + kernel_size=1, + bias=False) + self.key_conv.kaiming_init = True + + self.v_dim = in_channels // num_heads + self.value_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=self.v_dim * num_heads, + kernel_size=1, + bias=False) + self.value_conv.kaiming_init = True + + if self.attention_type[1] or self.attention_type[3]: + self.appr_geom_fc_x = nn.Linear( + self.position_embedding_dim // 2, out_c, bias=False) + self.appr_geom_fc_x.kaiming_init = True + + self.appr_geom_fc_y = nn.Linear( + self.position_embedding_dim // 2, out_c, bias=False) + self.appr_geom_fc_y.kaiming_init = True + + if self.attention_type[2]: + stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2) + appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv + self.appr_bias = nn.Parameter(appr_bias_value) + + if self.attention_type[3]: + stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2) + geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv + self.geom_bias = nn.Parameter(geom_bias_value) + + self.proj_conv = nn.Conv2d( + in_channels=self.v_dim * num_heads, + out_channels=in_channels, + kernel_size=1, + bias=True) + self.proj_conv.kaiming_init = True + self.gamma = nn.Parameter(torch.zeros(1)) + + if self.spatial_range >= 0: + # only works when non local is after 3*3 conv + if in_channels == 256: + max_len = 84 + elif in_channels == 512: + max_len = 42 + + max_len_kv = int((max_len - 1.0) / self.kv_stride + 1) + local_constraint_map = np.ones( + (max_len, max_len, max_len_kv, max_len_kv), dtype=np.int) + for iy in range(max_len): + for ix in range(max_len): + local_constraint_map[ + iy, ix, + max((iy - self.spatial_range) // + self.kv_stride, 0):min((iy + self.spatial_range + + 1) // self.kv_stride + + 1, max_len), + max((ix - self.spatial_range) // + self.kv_stride, 0):min((ix + self.spatial_range + + 1) // self.kv_stride + + 1, max_len)] = 0 + + self.local_constraint_map = nn.Parameter( + torch.from_numpy(local_constraint_map).byte(), + requires_grad=False) + + if self.q_stride > 1: + self.q_downsample = nn.AvgPool2d( + kernel_size=1, stride=self.q_stride) + else: + self.q_downsample = None + + if self.kv_stride > 1: + self.kv_downsample = nn.AvgPool2d( + kernel_size=1, stride=self.kv_stride) + else: + self.kv_downsample = None + + self.init_weights() + + def get_position_embedding(self, + h, + w, + h_kv, + w_kv, + q_stride, + kv_stride, + device, + dtype, + feat_dim, + wave_length=1000): + # the default type of Tensor is float32, leading to type mismatch + # in fp16 mode. Cast it to support fp16 mode. + h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype) + h_idxs = h_idxs.view((h, 1)) * q_stride + + w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype) + w_idxs = w_idxs.view((w, 1)) * q_stride + + h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to( + device=device, dtype=dtype) + h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride + + w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to( + device=device, dtype=dtype) + w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride + + # (h, h_kv, 1) + h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0) + h_diff *= self.position_magnitude + + # (w, w_kv, 1) + w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0) + w_diff *= self.position_magnitude + + feat_range = torch.arange(0, feat_dim / 4).to( + device=device, dtype=dtype) + + dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype) + dim_mat = dim_mat**((4. / feat_dim) * feat_range) + dim_mat = dim_mat.view((1, 1, -1)) + + embedding_x = torch.cat( + ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2) + + embedding_y = torch.cat( + ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2) + + return embedding_x, embedding_y + + def forward(self, x_input): + num_heads = self.num_heads + + # use empirical_attention + if self.q_downsample is not None: + x_q = self.q_downsample(x_input) + else: + x_q = x_input + n, _, h, w = x_q.shape + + if self.kv_downsample is not None: + x_kv = self.kv_downsample(x_input) + else: + x_kv = x_input + _, _, h_kv, w_kv = x_kv.shape + + if self.attention_type[0] or self.attention_type[1]: + proj_query = self.query_conv(x_q).view( + (n, num_heads, self.qk_embed_dim, h * w)) + proj_query = proj_query.permute(0, 1, 3, 2) + + if self.attention_type[0] or self.attention_type[2]: + proj_key = self.key_conv(x_kv).view( + (n, num_heads, self.qk_embed_dim, h_kv * w_kv)) + + if self.attention_type[1] or self.attention_type[3]: + position_embed_x, position_embed_y = self.get_position_embedding( + h, w, h_kv, w_kv, self.q_stride, self.kv_stride, + x_input.device, x_input.dtype, self.position_embedding_dim) + # (n, num_heads, w, w_kv, dim) + position_feat_x = self.appr_geom_fc_x(position_embed_x).\ + view(1, w, w_kv, num_heads, self.qk_embed_dim).\ + permute(0, 3, 1, 2, 4).\ + repeat(n, 1, 1, 1, 1) + + # (n, num_heads, h, h_kv, dim) + position_feat_y = self.appr_geom_fc_y(position_embed_y).\ + view(1, h, h_kv, num_heads, self.qk_embed_dim).\ + permute(0, 3, 1, 2, 4).\ + repeat(n, 1, 1, 1, 1) + + position_feat_x /= math.sqrt(2) + position_feat_y /= math.sqrt(2) + + # accelerate for saliency only + if (np.sum(self.attention_type) == 1) and self.attention_type[2]: + appr_bias = self.appr_bias.\ + view(1, num_heads, 1, self.qk_embed_dim).\ + repeat(n, 1, 1, 1) + + energy = torch.matmul(appr_bias, proj_key).\ + view(n, num_heads, 1, h_kv * w_kv) + + h = 1 + w = 1 + else: + # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for + if not self.attention_type[0]: + energy = torch.zeros( + n, + num_heads, + h, + w, + h_kv, + w_kv, + dtype=x_input.dtype, + device=x_input.device) + + # attention_type[0]: appr - appr + # attention_type[1]: appr - position + # attention_type[2]: bias - appr + # attention_type[3]: bias - position + if self.attention_type[0] or self.attention_type[2]: + if self.attention_type[0] and self.attention_type[2]: + appr_bias = self.appr_bias.\ + view(1, num_heads, 1, self.qk_embed_dim) + energy = torch.matmul(proj_query + appr_bias, proj_key).\ + view(n, num_heads, h, w, h_kv, w_kv) + + elif self.attention_type[0]: + energy = torch.matmul(proj_query, proj_key).\ + view(n, num_heads, h, w, h_kv, w_kv) + + elif self.attention_type[2]: + appr_bias = self.appr_bias.\ + view(1, num_heads, 1, self.qk_embed_dim).\ + repeat(n, 1, 1, 1) + + energy += torch.matmul(appr_bias, proj_key).\ + view(n, num_heads, 1, 1, h_kv, w_kv) + + if self.attention_type[1] or self.attention_type[3]: + if self.attention_type[1] and self.attention_type[3]: + geom_bias = self.geom_bias.\ + view(1, num_heads, 1, self.qk_embed_dim) + + proj_query_reshape = (proj_query + geom_bias).\ + view(n, num_heads, h, w, self.qk_embed_dim) + + energy_x = torch.matmul( + proj_query_reshape.permute(0, 1, 3, 2, 4), + position_feat_x.permute(0, 1, 2, 4, 3)) + energy_x = energy_x.\ + permute(0, 1, 3, 2, 4).unsqueeze(4) + + energy_y = torch.matmul( + proj_query_reshape, + position_feat_y.permute(0, 1, 2, 4, 3)) + energy_y = energy_y.unsqueeze(5) + + energy += energy_x + energy_y + + elif self.attention_type[1]: + proj_query_reshape = proj_query.\ + view(n, num_heads, h, w, self.qk_embed_dim) + proj_query_reshape = proj_query_reshape.\ + permute(0, 1, 3, 2, 4) + position_feat_x_reshape = position_feat_x.\ + permute(0, 1, 2, 4, 3) + position_feat_y_reshape = position_feat_y.\ + permute(0, 1, 2, 4, 3) + + energy_x = torch.matmul(proj_query_reshape, + position_feat_x_reshape) + energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4) + + energy_y = torch.matmul(proj_query_reshape, + position_feat_y_reshape) + energy_y = energy_y.unsqueeze(5) + + energy += energy_x + energy_y + + elif self.attention_type[3]: + geom_bias = self.geom_bias.\ + view(1, num_heads, self.qk_embed_dim, 1).\ + repeat(n, 1, 1, 1) + + position_feat_x_reshape = position_feat_x.\ + view(n, num_heads, w*w_kv, self.qk_embed_dim) + + position_feat_y_reshape = position_feat_y.\ + view(n, num_heads, h * h_kv, self.qk_embed_dim) + + energy_x = torch.matmul(position_feat_x_reshape, geom_bias) + energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv) + + energy_y = torch.matmul(position_feat_y_reshape, geom_bias) + energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1) + + energy += energy_x + energy_y + + energy = energy.view(n, num_heads, h * w, h_kv * w_kv) + + if self.spatial_range >= 0: + cur_local_constraint_map = \ + self.local_constraint_map[:h, :w, :h_kv, :w_kv].\ + contiguous().\ + view(1, 1, h*w, h_kv*w_kv) + + energy = energy.masked_fill_(cur_local_constraint_map, + float('-inf')) + + attention = F.softmax(energy, 3) + + proj_value = self.value_conv(x_kv) + proj_value_reshape = proj_value.\ + view((n, num_heads, self.v_dim, h_kv * w_kv)).\ + permute(0, 1, 3, 2) + + out = torch.matmul(attention, proj_value_reshape).\ + permute(0, 1, 3, 2).\ + contiguous().\ + view(n, self.v_dim * self.num_heads, h, w) + + out = self.proj_conv(out) + + # output is downsampled, upsample back to input size + if self.q_downsample is not None: + out = F.interpolate( + out, + size=x_input.shape[2:], + mode='bilinear', + align_corners=False) + + out = self.gamma * out + x_input + return out + + def init_weights(self): + for m in self.modules(): + if hasattr(m, 'kaiming_init') and m.kaiming_init: + kaiming_init( + m, + mode='fan_in', + nonlinearity='leaky_relu', + bias=0, + distribution='uniform', + a=1) diff --git a/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py b/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py new file mode 100644 index 0000000000000000000000000000000000000000..30b1a3d6580cf0360710426fbea1f05acdf07b4b --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from .registry import ACTIVATION_LAYERS + + +@ACTIVATION_LAYERS.register_module() +class HSigmoid(nn.Module): + """Hard Sigmoid Module. Apply the hard sigmoid function: + Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value) + Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1) + + Args: + bias (float): Bias of the input feature map. Default: 1.0. + divisor (float): Divisor of the input feature map. Default: 2.0. + min_value (float): Lower bound value. Default: 0.0. + max_value (float): Upper bound value. Default: 1.0. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0): + super(HSigmoid, self).__init__() + self.bias = bias + self.divisor = divisor + assert self.divisor != 0 + self.min_value = min_value + self.max_value = max_value + + def forward(self, x): + x = (x + self.bias) / self.divisor + + return x.clamp_(self.min_value, self.max_value) diff --git a/annotator/uniformer/mmcv/cnn/bricks/hswish.py b/annotator/uniformer/mmcv/cnn/bricks/hswish.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0c090ff037c99ee6c5c84c4592e87beae02208 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/hswish.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from .registry import ACTIVATION_LAYERS + + +@ACTIVATION_LAYERS.register_module() +class HSwish(nn.Module): + """Hard Swish Module. + + This module applies the hard swish function: + + .. math:: + Hswish(x) = x * ReLU6(x + 3) / 6 + + Args: + inplace (bool): can optionally do the operation in-place. + Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, inplace=False): + super(HSwish, self).__init__() + self.act = nn.ReLU6(inplace) + + def forward(self, x): + return x * self.act(x + 3) / 6 diff --git a/annotator/uniformer/mmcv/cnn/bricks/non_local.py b/annotator/uniformer/mmcv/cnn/bricks/non_local.py new file mode 100644 index 0000000000000000000000000000000000000000..92d00155ef275c1201ea66bba30470a1785cc5d7 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/non_local.py @@ -0,0 +1,306 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta + +import torch +import torch.nn as nn + +from ..utils import constant_init, normal_init +from .conv_module import ConvModule +from .registry import PLUGIN_LAYERS + + +class _NonLocalNd(nn.Module, metaclass=ABCMeta): + """Basic Non-local module. + + This module is proposed in + "Non-local Neural Networks" + Paper reference: https://arxiv.org/abs/1711.07971 + Code reference: https://github.com/AlexHex7/Non-local_pytorch + + Args: + in_channels (int): Channels of the input feature map. + reduction (int): Channel reduction ratio. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`. + Default: True. + conv_cfg (None | dict): The config dict for convolution layers. + If not specified, it will use `nn.Conv2d` for convolution layers. + Default: None. + norm_cfg (None | dict): The config dict for normalization layers. + Default: None. (This parameter is only applicable to conv_out.) + mode (str): Options are `gaussian`, `concatenation`, + `embedded_gaussian` and `dot_product`. Default: embedded_gaussian. + """ + + def __init__(self, + in_channels, + reduction=2, + use_scale=True, + conv_cfg=None, + norm_cfg=None, + mode='embedded_gaussian', + **kwargs): + super(_NonLocalNd, self).__init__() + self.in_channels = in_channels + self.reduction = reduction + self.use_scale = use_scale + self.inter_channels = max(in_channels // reduction, 1) + self.mode = mode + + if mode not in [ + 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation' + ]: + raise ValueError("Mode should be in 'gaussian', 'concatenation', " + f"'embedded_gaussian' or 'dot_product', but got " + f'{mode} instead.') + + # g, theta, phi are defaulted as `nn.ConvNd`. + # Here we use ConvModule for potential usage. + self.g = ConvModule( + self.in_channels, + self.inter_channels, + kernel_size=1, + conv_cfg=conv_cfg, + act_cfg=None) + self.conv_out = ConvModule( + self.inter_channels, + self.in_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + if self.mode != 'gaussian': + self.theta = ConvModule( + self.in_channels, + self.inter_channels, + kernel_size=1, + conv_cfg=conv_cfg, + act_cfg=None) + self.phi = ConvModule( + self.in_channels, + self.inter_channels, + kernel_size=1, + conv_cfg=conv_cfg, + act_cfg=None) + + if self.mode == 'concatenation': + self.concat_project = ConvModule( + self.inter_channels * 2, + 1, + kernel_size=1, + stride=1, + padding=0, + bias=False, + act_cfg=dict(type='ReLU')) + + self.init_weights(**kwargs) + + def init_weights(self, std=0.01, zeros_init=True): + if self.mode != 'gaussian': + for m in [self.g, self.theta, self.phi]: + normal_init(m.conv, std=std) + else: + normal_init(self.g.conv, std=std) + if zeros_init: + if self.conv_out.norm_cfg is None: + constant_init(self.conv_out.conv, 0) + else: + constant_init(self.conv_out.norm, 0) + else: + if self.conv_out.norm_cfg is None: + normal_init(self.conv_out.conv, std=std) + else: + normal_init(self.conv_out.norm, std=std) + + def gaussian(self, theta_x, phi_x): + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + pairwise_weight = pairwise_weight.softmax(dim=-1) + return pairwise_weight + + def embedded_gaussian(self, theta_x, phi_x): + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + if self.use_scale: + # theta_x.shape[-1] is `self.inter_channels` + pairwise_weight /= theta_x.shape[-1]**0.5 + pairwise_weight = pairwise_weight.softmax(dim=-1) + return pairwise_weight + + def dot_product(self, theta_x, phi_x): + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + pairwise_weight /= pairwise_weight.shape[-1] + return pairwise_weight + + def concatenation(self, theta_x, phi_x): + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + h = theta_x.size(2) + w = phi_x.size(3) + theta_x = theta_x.repeat(1, 1, 1, w) + phi_x = phi_x.repeat(1, 1, h, 1) + + concat_feature = torch.cat([theta_x, phi_x], dim=1) + pairwise_weight = self.concat_project(concat_feature) + n, _, h, w = pairwise_weight.size() + pairwise_weight = pairwise_weight.view(n, h, w) + pairwise_weight /= pairwise_weight.shape[-1] + + return pairwise_weight + + def forward(self, x): + # Assume `reduction = 1`, then `inter_channels = C` + # or `inter_channels = C` when `mode="gaussian"` + + # NonLocal1d x: [N, C, H] + # NonLocal2d x: [N, C, H, W] + # NonLocal3d x: [N, C, T, H, W] + n = x.size(0) + + # NonLocal1d g_x: [N, H, C] + # NonLocal2d g_x: [N, HxW, C] + # NonLocal3d g_x: [N, TxHxW, C] + g_x = self.g(x).view(n, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H] + # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW] + # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW] + if self.mode == 'gaussian': + theta_x = x.view(n, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + if self.sub_sample: + phi_x = self.phi(x).view(n, self.in_channels, -1) + else: + phi_x = x.view(n, self.in_channels, -1) + elif self.mode == 'concatenation': + theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) + phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) + else: + theta_x = self.theta(x).view(n, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(n, self.inter_channels, -1) + + pairwise_func = getattr(self, self.mode) + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + pairwise_weight = pairwise_func(theta_x, phi_x) + + # NonLocal1d y: [N, H, C] + # NonLocal2d y: [N, HxW, C] + # NonLocal3d y: [N, TxHxW, C] + y = torch.matmul(pairwise_weight, g_x) + # NonLocal1d y: [N, C, H] + # NonLocal2d y: [N, C, H, W] + # NonLocal3d y: [N, C, T, H, W] + y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, + *x.size()[2:]) + + output = x + self.conv_out(y) + + return output + + +class NonLocal1d(_NonLocalNd): + """1D Non-local module. + + Args: + in_channels (int): Same as `NonLocalND`. + sub_sample (bool): Whether to apply max pooling after pairwise + function (Note that the `sub_sample` is applied on spatial only). + Default: False. + conv_cfg (None | dict): Same as `NonLocalND`. + Default: dict(type='Conv1d'). + """ + + def __init__(self, + in_channels, + sub_sample=False, + conv_cfg=dict(type='Conv1d'), + **kwargs): + super(NonLocal1d, self).__init__( + in_channels, conv_cfg=conv_cfg, **kwargs) + + self.sub_sample = sub_sample + + if sub_sample: + max_pool_layer = nn.MaxPool1d(kernel_size=2) + self.g = nn.Sequential(self.g, max_pool_layer) + if self.mode != 'gaussian': + self.phi = nn.Sequential(self.phi, max_pool_layer) + else: + self.phi = max_pool_layer + + +@PLUGIN_LAYERS.register_module() +class NonLocal2d(_NonLocalNd): + """2D Non-local module. + + Args: + in_channels (int): Same as `NonLocalND`. + sub_sample (bool): Whether to apply max pooling after pairwise + function (Note that the `sub_sample` is applied on spatial only). + Default: False. + conv_cfg (None | dict): Same as `NonLocalND`. + Default: dict(type='Conv2d'). + """ + + _abbr_ = 'nonlocal_block' + + def __init__(self, + in_channels, + sub_sample=False, + conv_cfg=dict(type='Conv2d'), + **kwargs): + super(NonLocal2d, self).__init__( + in_channels, conv_cfg=conv_cfg, **kwargs) + + self.sub_sample = sub_sample + + if sub_sample: + max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) + self.g = nn.Sequential(self.g, max_pool_layer) + if self.mode != 'gaussian': + self.phi = nn.Sequential(self.phi, max_pool_layer) + else: + self.phi = max_pool_layer + + +class NonLocal3d(_NonLocalNd): + """3D Non-local module. + + Args: + in_channels (int): Same as `NonLocalND`. + sub_sample (bool): Whether to apply max pooling after pairwise + function (Note that the `sub_sample` is applied on spatial only). + Default: False. + conv_cfg (None | dict): Same as `NonLocalND`. + Default: dict(type='Conv3d'). + """ + + def __init__(self, + in_channels, + sub_sample=False, + conv_cfg=dict(type='Conv3d'), + **kwargs): + super(NonLocal3d, self).__init__( + in_channels, conv_cfg=conv_cfg, **kwargs) + self.sub_sample = sub_sample + + if sub_sample: + max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) + self.g = nn.Sequential(self.g, max_pool_layer) + if self.mode != 'gaussian': + self.phi = nn.Sequential(self.phi, max_pool_layer) + else: + self.phi = max_pool_layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/norm.py b/annotator/uniformer/mmcv/cnn/bricks/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..408f4b42731b19a3beeef68b6a5e610d0bbc18b3 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/norm.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect + +import torch.nn as nn + +from annotator.uniformer.mmcv.utils import is_tuple_of +from annotator.uniformer.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm +from .registry import NORM_LAYERS + +NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d) +NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d) +NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d) +NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d) +NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm) +NORM_LAYERS.register_module('GN', module=nn.GroupNorm) +NORM_LAYERS.register_module('LN', module=nn.LayerNorm) +NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d) +NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d) +NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d) +NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d) + + +def infer_abbr(class_type): + """Infer abbreviation from the class name. + + When we build a norm layer with `build_norm_layer()`, we want to preserve + the norm type in variable names, e.g, self.bn1, self.gn. This method will + infer the abbreviation to map class types to abbreviations. + + Rule 1: If the class has the property "_abbr_", return the property. + Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or + InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and + "in" respectively. + Rule 3: If the class name contains "batch", "group", "layer" or "instance", + the abbreviation of this layer will be "bn", "gn", "ln" and "in" + respectively. + Rule 4: Otherwise, the abbreviation falls back to "norm". + + Args: + class_type (type): The norm layer type. + + Returns: + str: The inferred abbreviation. + """ + if not inspect.isclass(class_type): + raise TypeError( + f'class_type must be a type, but got {type(class_type)}') + if hasattr(class_type, '_abbr_'): + return class_type._abbr_ + if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN + return 'in' + elif issubclass(class_type, _BatchNorm): + return 'bn' + elif issubclass(class_type, nn.GroupNorm): + return 'gn' + elif issubclass(class_type, nn.LayerNorm): + return 'ln' + else: + class_name = class_type.__name__.lower() + if 'batch' in class_name: + return 'bn' + elif 'group' in class_name: + return 'gn' + elif 'layer' in class_name: + return 'ln' + elif 'instance' in class_name: + return 'in' + else: + return 'norm_layer' + + +def build_norm_layer(cfg, num_features, postfix=''): + """Build normalization layer. + + Args: + cfg (dict): The norm layer config, which should contain: + + - type (str): Layer type. + - layer args: Args needed to instantiate a norm layer. + - requires_grad (bool, optional): Whether stop gradient updates. + num_features (int): Number of input channels. + postfix (int | str): The postfix to be appended into norm abbreviation + to create named layer. + + Returns: + (str, nn.Module): The first element is the layer name consisting of + abbreviation and postfix, e.g., bn1, gn. The second element is the + created norm layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in NORM_LAYERS: + raise KeyError(f'Unrecognized norm type {layer_type}') + + norm_layer = NORM_LAYERS.get(layer_type) + abbr = infer_abbr(norm_layer) + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + requires_grad = cfg_.pop('requires_grad', True) + cfg_.setdefault('eps', 1e-5) + if layer_type != 'GN': + layer = norm_layer(num_features, **cfg_) + if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): + layer._specify_ddp_gpu_num(1) + else: + assert 'num_groups' in cfg_ + layer = norm_layer(num_channels=num_features, **cfg_) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return name, layer + + +def is_norm(layer, exclude=None): + """Check if a layer is a normalization layer. + + Args: + layer (nn.Module): The layer to be checked. + exclude (type | tuple[type]): Types to be excluded. + + Returns: + bool: Whether the layer is a norm layer. + """ + if exclude is not None: + if not isinstance(exclude, tuple): + exclude = (exclude, ) + if not is_tuple_of(exclude, type): + raise TypeError( + f'"exclude" must be either None or type or a tuple of types, ' + f'but got {type(exclude)}: {exclude}') + + if exclude and isinstance(layer, exclude): + return False + + all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) + return isinstance(layer, all_norm_bases) diff --git a/annotator/uniformer/mmcv/cnn/bricks/padding.py b/annotator/uniformer/mmcv/cnn/bricks/padding.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ac6b28a1789bd551c613a7d3e7b622433ac7ec --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/padding.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from .registry import PADDING_LAYERS + +PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d) +PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d) +PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d) + + +def build_padding_layer(cfg, *args, **kwargs): + """Build padding layer. + + Args: + cfg (None or dict): The padding layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate a padding layer. + + Returns: + nn.Module: Created padding layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + + cfg_ = cfg.copy() + padding_type = cfg_.pop('type') + if padding_type not in PADDING_LAYERS: + raise KeyError(f'Unrecognized padding type {padding_type}.') + else: + padding_layer = PADDING_LAYERS.get(padding_type) + + layer = padding_layer(*args, **kwargs, **cfg_) + + return layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/plugin.py b/annotator/uniformer/mmcv/cnn/bricks/plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..07c010d4053174dd41107aa654ea67e82b46a25c --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/plugin.py @@ -0,0 +1,88 @@ +import inspect +import platform + +from .registry import PLUGIN_LAYERS + +if platform.system() == 'Windows': + import regex as re +else: + import re + + +def infer_abbr(class_type): + """Infer abbreviation from the class name. + + This method will infer the abbreviation to map class types to + abbreviations. + + Rule 1: If the class has the property "abbr", return the property. + Rule 2: Otherwise, the abbreviation falls back to snake case of class + name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``. + + Args: + class_type (type): The norm layer type. + + Returns: + str: The inferred abbreviation. + """ + + def camel2snack(word): + """Convert camel case word into snack case. + + Modified from `inflection lib + `_. + + Example:: + + >>> camel2snack("FancyBlock") + 'fancy_block' + """ + + word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word) + word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word) + word = word.replace('-', '_') + return word.lower() + + if not inspect.isclass(class_type): + raise TypeError( + f'class_type must be a type, but got {type(class_type)}') + if hasattr(class_type, '_abbr_'): + return class_type._abbr_ + else: + return camel2snack(class_type.__name__) + + +def build_plugin_layer(cfg, postfix='', **kwargs): + """Build plugin layer. + + Args: + cfg (None or dict): cfg should contain: + type (str): identify plugin layer type. + layer args: args needed to instantiate a plugin layer. + postfix (int, str): appended into norm abbreviation to + create named layer. Default: ''. + + Returns: + tuple[str, nn.Module]: + name (str): abbreviation + postfix + layer (nn.Module): created plugin layer + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in PLUGIN_LAYERS: + raise KeyError(f'Unrecognized plugin type {layer_type}') + + plugin_layer = PLUGIN_LAYERS.get(layer_type) + abbr = infer_abbr(plugin_layer) + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + layer = plugin_layer(**kwargs, **cfg_) + + return name, layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/registry.py b/annotator/uniformer/mmcv/cnn/bricks/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..39eabc58db4b5954478a2ac1ab91cea5e45ab055 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/registry.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from annotator.uniformer.mmcv.utils import Registry + +CONV_LAYERS = Registry('conv layer') +NORM_LAYERS = Registry('norm layer') +ACTIVATION_LAYERS = Registry('activation layer') +PADDING_LAYERS = Registry('padding layer') +UPSAMPLE_LAYERS = Registry('upsample layer') +PLUGIN_LAYERS = Registry('plugin layer') + +DROPOUT_LAYERS = Registry('drop out layers') +POSITIONAL_ENCODING = Registry('position encoding') +ATTENTION = Registry('attention') +FEEDFORWARD_NETWORK = Registry('feed-forward Network') +TRANSFORMER_LAYER = Registry('transformerLayer') +TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence') diff --git a/annotator/uniformer/mmcv/cnn/bricks/scale.py b/annotator/uniformer/mmcv/cnn/bricks/scale.py new file mode 100644 index 0000000000000000000000000000000000000000..c905fffcc8bf998d18d94f927591963c428025e2 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/scale.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + + +class Scale(nn.Module): + """A learnable scale parameter. + + This layer scales the input by a learnable factor. It multiplies a + learnable scale parameter of shape (1,) with input of any shape. + + Args: + scale (float): Initial value of scale factor. Default: 1.0 + """ + + def __init__(self, scale=1.0): + super(Scale, self).__init__() + self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) + + def forward(self, x): + return x * self.scale diff --git a/annotator/uniformer/mmcv/cnn/bricks/swish.py b/annotator/uniformer/mmcv/cnn/bricks/swish.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ca8ed7b749413f011ae54aac0cab27e6f0b51f --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/swish.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from .registry import ACTIVATION_LAYERS + + +@ACTIVATION_LAYERS.register_module() +class Swish(nn.Module): + """Swish Module. + + This module applies the swish function: + + .. math:: + Swish(x) = x * Sigmoid(x) + + Returns: + Tensor: The output tensor. + """ + + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + return x * torch.sigmoid(x) diff --git a/annotator/uniformer/mmcv/cnn/bricks/transformer.py b/annotator/uniformer/mmcv/cnn/bricks/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e61ae0dd941a7be00b3e41a3de833ec50470a45f --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/transformer.py @@ -0,0 +1,595 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings + +import torch +import torch.nn as nn + +from annotator.uniformer.mmcv import ConfigDict, deprecated_api_warning +from annotator.uniformer.mmcv.cnn import Linear, build_activation_layer, build_norm_layer +from annotator.uniformer.mmcv.runner.base_module import BaseModule, ModuleList, Sequential +from annotator.uniformer.mmcv.utils import build_from_cfg +from .drop import build_dropout +from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING, + TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE) + +# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file +try: + from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401 + warnings.warn( + ImportWarning( + '``MultiScaleDeformableAttention`` has been moved to ' + '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501 + '``from annotator.uniformer.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501 + 'to ``from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501 + )) + +except ImportError: + warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from ' + '``mmcv.ops.multi_scale_deform_attn``, ' + 'You should install ``mmcv-full`` if you need this module. ') + + +def build_positional_encoding(cfg, default_args=None): + """Builder for Position Encoding.""" + return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args) + + +def build_attention(cfg, default_args=None): + """Builder for attention.""" + return build_from_cfg(cfg, ATTENTION, default_args) + + +def build_feedforward_network(cfg, default_args=None): + """Builder for feed-forward network (FFN).""" + return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args) + + +def build_transformer_layer(cfg, default_args=None): + """Builder for transformer layer.""" + return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args) + + +def build_transformer_layer_sequence(cfg, default_args=None): + """Builder for transformer encoder and transformer decoder.""" + return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args) + + +@ATTENTION.register_module() +class MultiheadAttention(BaseModule): + """A wrapper for ``torch.nn.MultiheadAttention``. + + This module implements MultiheadAttention with identity connection, + and positional encoding is also passed as input. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): When it is True, Key, Query and Value are shape of + (batch, n, embed_dim), otherwise (n, batch, embed_dim). + Default to False. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + init_cfg=None, + batch_first=False, + **kwargs): + super(MultiheadAttention, self).__init__(init_cfg) + if 'dropout' in kwargs: + warnings.warn('The arguments `dropout` in MultiheadAttention ' + 'has been deprecated, now you can separately ' + 'set `attn_drop`(float), proj_drop(float), ' + 'and `dropout_layer`(dict) ') + attn_drop = kwargs['dropout'] + dropout_layer['drop_prob'] = kwargs.pop('dropout') + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.batch_first = batch_first + + self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, + **kwargs) + + self.proj_drop = nn.Dropout(proj_drop) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + + @deprecated_api_warning({'residual': 'identity'}, + cls_name='MultiheadAttention') + def forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_pos=None, + attn_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `MultiheadAttention`. + + **kwargs allow passing a more general data flow when combining + with other operations in `transformerlayer`. + + Args: + query (Tensor): The input query with shape [num_queries, bs, + embed_dims] if self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + If None, the ``query`` will be used. Defaults to None. + value (Tensor): The value tensor with same shape as `key`. + Same in `nn.MultiheadAttention.forward`. Defaults to None. + If None, the `key` will be used. + identity (Tensor): This tensor, with the same shape as x, + will be used for the identity link. + If None, `x` will be used. Defaults to None. + query_pos (Tensor): The positional encoding for query, with + the same shape as `x`. If not None, it will + be added to `x` before forward function. Defaults to None. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. Defaults to None. If not None, it will + be added to `key` before forward function. If None, and + `query_pos` has the same shape as `key`, then `query_pos` + will be used for `key_pos`. Defaults to None. + attn_mask (Tensor): ByteTensor mask with shape [num_queries, + num_keys]. Same in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. + Defaults to None. + + Returns: + Tensor: forwarded results with shape + [num_queries, bs, embed_dims] + if self.batch_first is False, else + [bs, num_queries embed_dims]. + """ + + if key is None: + key = query + if value is None: + value = key + if identity is None: + identity = query + if key_pos is None: + if query_pos is not None: + # use query_pos if key_pos is not available + if query_pos.shape == key.shape: + key_pos = query_pos + else: + warnings.warn(f'position encoding of key is' + f'missing in {self.__class__.__name__}.') + if query_pos is not None: + query = query + query_pos + if key_pos is not None: + key = key + key_pos + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + out = self.attn( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + +@FEEDFORWARD_NETWORK.register_module() +class FFN(BaseModule): + """Implements feed-forward networks (FFNs) with identity connection. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning( + { + 'dropout': 'ffn_drop', + 'add_residual': 'add_identity' + }, + cls_name='FFN') + def __init__(self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0., + dropout_layer=None, + add_identity=True, + init_cfg=None, + **kwargs): + super(FFN, self).__init__(init_cfg) + assert num_fcs >= 2, 'num_fcs should be no less ' \ + f'than 2. got {num_fcs}.' + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + Sequential( + Linear(in_channels, feedforward_channels), self.activate, + nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN') + def forward(self, x, identity=None): + """Forward function for `FFN`. + + The function would add x to the output tensor if residue is None. + """ + out = self.layers(x) + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +@TRANSFORMER_LAYER.register_module() +class BaseTransformerLayer(BaseModule): + """Base `TransformerLayer` for vision transformer. + + It can be built from `mmcv.ConfigDict` and support more flexible + customization, for example, using any number of `FFN or LN ` and + use different kinds of `attention` by specifying a list of `ConfigDict` + named `attn_cfgs`. It is worth mentioning that it supports `prenorm` + when you specifying `norm` as the first element of `operation_order`. + More details about the `prenorm`: `On Layer Normalization in the + Transformer Architecture `_ . + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): + Configs for `self_attention` or `cross_attention` modules, + The order of the configs in the list should be consistent with + corresponding attentions in operation_order. + If it is a dict, all of the attention modules in operation_order + will be built with this config. Default: None. + ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): + Configs for FFN, The order of the configs in the list should be + consistent with corresponding ffn in operation_order. + If it is a dict, all of the attention modules in operation_order + will be built with this config. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Support `prenorm` when you specifying first element as `norm`. + Default:None. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Key, Query and Value are shape + of (batch, n, embed_dim) + or (n, batch, embed_dim). Default to False. + """ + + def __init__(self, + attn_cfgs=None, + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True), + ), + operation_order=None, + norm_cfg=dict(type='LN'), + init_cfg=None, + batch_first=False, + **kwargs): + + deprecated_args = dict( + feedforward_channels='feedforward_channels', + ffn_dropout='ffn_drop', + ffn_num_fcs='num_fcs') + for ori_name, new_name in deprecated_args.items(): + if ori_name in kwargs: + warnings.warn( + f'The arguments `{ori_name}` in BaseTransformerLayer ' + f'has been deprecated, now you should set `{new_name}` ' + f'and other FFN related arguments ' + f'to a dict named `ffn_cfgs`. ') + ffn_cfgs[new_name] = kwargs[ori_name] + + super(BaseTransformerLayer, self).__init__(init_cfg) + + self.batch_first = batch_first + + assert set(operation_order) & set( + ['self_attn', 'norm', 'ffn', 'cross_attn']) == \ + set(operation_order), f'The operation_order of' \ + f' {self.__class__.__name__} should ' \ + f'contains all four operation type ' \ + f"{['self_attn', 'norm', 'ffn', 'cross_attn']}" + + num_attn = operation_order.count('self_attn') + operation_order.count( + 'cross_attn') + if isinstance(attn_cfgs, dict): + attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)] + else: + assert num_attn == len(attn_cfgs), f'The length ' \ + f'of attn_cfg {num_attn} is ' \ + f'not consistent with the number of attention' \ + f'in operation_order {operation_order}.' + + self.num_attn = num_attn + self.operation_order = operation_order + self.norm_cfg = norm_cfg + self.pre_norm = operation_order[0] == 'norm' + self.attentions = ModuleList() + + index = 0 + for operation_name in operation_order: + if operation_name in ['self_attn', 'cross_attn']: + if 'batch_first' in attn_cfgs[index]: + assert self.batch_first == attn_cfgs[index]['batch_first'] + else: + attn_cfgs[index]['batch_first'] = self.batch_first + attention = build_attention(attn_cfgs[index]) + # Some custom attentions used as `self_attn` + # or `cross_attn` can have different behavior. + attention.operation_name = operation_name + self.attentions.append(attention) + index += 1 + + self.embed_dims = self.attentions[0].embed_dims + + self.ffns = ModuleList() + num_ffns = operation_order.count('ffn') + if isinstance(ffn_cfgs, dict): + ffn_cfgs = ConfigDict(ffn_cfgs) + if isinstance(ffn_cfgs, dict): + ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)] + assert len(ffn_cfgs) == num_ffns + for ffn_index in range(num_ffns): + if 'embed_dims' not in ffn_cfgs[ffn_index]: + ffn_cfgs['embed_dims'] = self.embed_dims + else: + assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims + self.ffns.append( + build_feedforward_network(ffn_cfgs[ffn_index], + dict(type='FFN'))) + + self.norms = ModuleList() + num_norms = operation_order.count('norm') + for _ in range(num_norms): + self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1]) + + def forward(self, + query, + key=None, + value=None, + query_pos=None, + key_pos=None, + attn_masks=None, + query_key_padding_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `TransformerDecoderLayer`. + + **kwargs contains some specific arguments of attentions. + + Args: + query (Tensor): The input query with shape + [num_queries, bs, embed_dims] if + self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + value (Tensor): The value tensor with same shape as `key`. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. + Default: None. + attn_masks (List[Tensor] | None): 2D Tensor used in + calculation of corresponding attention. The length of + it should equal to the number of `attention` in + `operation_order`. Default: None. + query_key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_queries]. Only used in `self_attn` layer. + Defaults to None. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_keys]. Default: None. + + Returns: + Tensor: forwarded results with shape [num_queries, bs, embed_dims]. + """ + + norm_index = 0 + attn_index = 0 + ffn_index = 0 + identity = query + if attn_masks is None: + attn_masks = [None for _ in range(self.num_attn)] + elif isinstance(attn_masks, torch.Tensor): + attn_masks = [ + copy.deepcopy(attn_masks) for _ in range(self.num_attn) + ] + warnings.warn(f'Use same attn_mask in all attentions in ' + f'{self.__class__.__name__} ') + else: + assert len(attn_masks) == self.num_attn, f'The length of ' \ + f'attn_masks {len(attn_masks)} must be equal ' \ + f'to the number of attention in ' \ + f'operation_order {self.num_attn}' + + for layer in self.operation_order: + if layer == 'self_attn': + temp_key = temp_value = query + query = self.attentions[attn_index]( + query, + temp_key, + temp_value, + identity if self.pre_norm else None, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=attn_masks[attn_index], + key_padding_mask=query_key_padding_mask, + **kwargs) + attn_index += 1 + identity = query + + elif layer == 'norm': + query = self.norms[norm_index](query) + norm_index += 1 + + elif layer == 'cross_attn': + query = self.attentions[attn_index]( + query, + key, + value, + identity if self.pre_norm else None, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=attn_masks[attn_index], + key_padding_mask=key_padding_mask, + **kwargs) + attn_index += 1 + identity = query + + elif layer == 'ffn': + query = self.ffns[ffn_index]( + query, identity if self.pre_norm else None) + ffn_index += 1 + + return query + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class TransformerLayerSequence(BaseModule): + """Base class for TransformerEncoder and TransformerDecoder in vision + transformer. + + As base-class of Encoder and Decoder in vision transformer. + Support customization such as specifying different kind + of `transformer_layer` in `transformer_coder`. + + Args: + transformerlayer (list[obj:`mmcv.ConfigDict`] | + obj:`mmcv.ConfigDict`): Config of transformerlayer + in TransformerCoder. If it is obj:`mmcv.ConfigDict`, + it would be repeated `num_layer` times to a + list[`mmcv.ConfigDict`]. Default: None. + num_layers (int): The number of `TransformerLayer`. Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None): + super(TransformerLayerSequence, self).__init__(init_cfg) + if isinstance(transformerlayers, dict): + transformerlayers = [ + copy.deepcopy(transformerlayers) for _ in range(num_layers) + ] + else: + assert isinstance(transformerlayers, list) and \ + len(transformerlayers) == num_layers + self.num_layers = num_layers + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append(build_transformer_layer(transformerlayers[i])) + self.embed_dims = self.layers[0].embed_dims + self.pre_norm = self.layers[0].pre_norm + + def forward(self, + query, + key, + value, + query_pos=None, + key_pos=None, + attn_masks=None, + query_key_padding_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `TransformerCoder`. + + Args: + query (Tensor): Input query with shape + `(num_queries, bs, embed_dims)`. + key (Tensor): The key tensor with shape + `(num_keys, bs, embed_dims)`. + value (Tensor): The value tensor with shape + `(num_keys, bs, embed_dims)`. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. + Default: None. + attn_masks (List[Tensor], optional): Each element is 2D Tensor + which is used in calculation of corresponding attention in + operation_order. Default: None. + query_key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_queries]. Only used in self-attention + Default: None. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_keys]. Default: None. + + Returns: + Tensor: results with shape [num_queries, bs, embed_dims]. + """ + for layer in self.layers: + query = layer( + query, + key, + value, + query_pos=query_pos, + key_pos=key_pos, + attn_masks=attn_masks, + query_key_padding_mask=query_key_padding_mask, + key_padding_mask=key_padding_mask, + **kwargs) + return query diff --git a/annotator/uniformer/mmcv/cnn/bricks/upsample.py b/annotator/uniformer/mmcv/cnn/bricks/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a353767d0ce8518f0d7289bed10dba0178ed12 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/upsample.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import xavier_init +from .registry import UPSAMPLE_LAYERS + +UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample) +UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample) + + +@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle') +class PixelShufflePack(nn.Module): + """Pixel Shuffle upsample layer. + + This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to + achieve a simple upsampling with pixel shuffle. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + scale_factor (int): Upsample ratio. + upsample_kernel (int): Kernel size of the conv layer to expand the + channels. + """ + + def __init__(self, in_channels, out_channels, scale_factor, + upsample_kernel): + super(PixelShufflePack, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.scale_factor = scale_factor + self.upsample_kernel = upsample_kernel + self.upsample_conv = nn.Conv2d( + self.in_channels, + self.out_channels * scale_factor * scale_factor, + self.upsample_kernel, + padding=(self.upsample_kernel - 1) // 2) + self.init_weights() + + def init_weights(self): + xavier_init(self.upsample_conv, distribution='uniform') + + def forward(self, x): + x = self.upsample_conv(x) + x = F.pixel_shuffle(x, self.scale_factor) + return x + + +def build_upsample_layer(cfg, *args, **kwargs): + """Build upsample layer. + + Args: + cfg (dict): The upsample layer config, which should contain: + + - type (str): Layer type. + - scale_factor (int): Upsample ratio, which is not applicable to + deconv. + - layer args: Args needed to instantiate a upsample layer. + args (argument list): Arguments passed to the ``__init__`` + method of the corresponding conv layer. + kwargs (keyword arguments): Keyword arguments passed to the + ``__init__`` method of the corresponding conv layer. + + Returns: + nn.Module: Created upsample layer. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + raise KeyError( + f'the cfg dict must contain the key "type", but got {cfg}') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in UPSAMPLE_LAYERS: + raise KeyError(f'Unrecognized upsample type {layer_type}') + else: + upsample = UPSAMPLE_LAYERS.get(layer_type) + + if upsample is nn.Upsample: + cfg_['mode'] = layer_type + layer = upsample(*args, **kwargs, **cfg_) + return layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/wrappers.py b/annotator/uniformer/mmcv/cnn/bricks/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..8aebf67bf52355a513f21756ee74fe510902d075 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/wrappers.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501 + +Wrap some nn modules to support empty tensor input. Currently, these wrappers +are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask +heads are trained on only positive RoIs. +""" +import math + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair, _triple + +from .registry import CONV_LAYERS, UPSAMPLE_LAYERS + +if torch.__version__ == 'parrots': + TORCH_VERSION = torch.__version__ +else: + # torch.__version__ could be 1.3.1+cu92, we only need the first two + # for comparison + TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2]) + + +def obsolete_torch_version(torch_version, version_threshold): + return torch_version == 'parrots' or torch_version <= version_threshold + + +class NewEmptyTensorOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return NewEmptyTensorOp.apply(grad, shape), None + + +@CONV_LAYERS.register_module('Conv', force=True) +class Conv2d(nn.Conv2d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, + self.padding, self.stride, self.dilation): + o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1 + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module('Conv3d', force=True) +class Conv3d(nn.Conv3d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size, + self.padding, self.stride, self.dilation): + o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1 + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module() +@CONV_LAYERS.register_module('deconv') +@UPSAMPLE_LAYERS.register_module('deconv', force=True) +class ConvTranspose2d(nn.ConvTranspose2d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, + self.padding, self.stride, + self.dilation, self.output_padding): + out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module() +@CONV_LAYERS.register_module('deconv3d') +@UPSAMPLE_LAYERS.register_module('deconv3d', force=True) +class ConvTranspose3d(nn.ConvTranspose3d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size, + self.padding, self.stride, + self.dilation, self.output_padding): + out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +class MaxPool2d(nn.MaxPool2d): + + def forward(self, x): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + out_shape = list(x.shape[:2]) + for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), + _pair(self.padding), _pair(self.stride), + _pair(self.dilation)): + o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1 + o = math.ceil(o) if self.ceil_mode else math.floor(o) + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + return empty + + return super().forward(x) + + +class MaxPool3d(nn.MaxPool3d): + + def forward(self, x): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + out_shape = list(x.shape[:2]) + for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size), + _triple(self.padding), + _triple(self.stride), + _triple(self.dilation)): + o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1 + o = math.ceil(o) if self.ceil_mode else math.floor(o) + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + return empty + + return super().forward(x) + + +class Linear(torch.nn.Linear): + + def forward(self, x): + # empty tensor forward of Linear layer is supported in Pytorch 1.6 + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)): + out_shape = [x.shape[0], self.out_features] + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) diff --git a/annotator/uniformer/mmcv/cnn/builder.py b/annotator/uniformer/mmcv/cnn/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/builder.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..runner import Sequential +from ..utils import Registry, build_from_cfg + + +def build_model_from_cfg(cfg, registry, default_args=None): + """Build a PyTorch model from config dict(s). Different from + ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built. + + Args: + cfg (dict, list[dict]): The config of modules, is is either a config + dict or a list of config dicts. If cfg is a list, a + the built modules will be wrapped with ``nn.Sequential``. + registry (:obj:`Registry`): A registry the module belongs to. + default_args (dict, optional): Default arguments to build the module. + Defaults to None. + + Returns: + nn.Module: A built nn module. + """ + if isinstance(cfg, list): + modules = [ + build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg + ] + return Sequential(*modules) + else: + return build_from_cfg(cfg, registry, default_args) + + +MODELS = Registry('model', build_func=build_model_from_cfg) diff --git a/annotator/uniformer/mmcv/cnn/resnet.py b/annotator/uniformer/mmcv/cnn/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb3ac057ee2d52c46fc94685b5d4e698aad8d5f --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/resnet.py @@ -0,0 +1,316 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.nn as nn +import torch.utils.checkpoint as cp + +from .utils import constant_init, kaiming_init + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False): + super(BasicBlock, self).__init__() + assert style in ['pytorch', 'caffe'] + self.conv1 = conv3x3(inplanes, planes, stride, dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + assert not with_cp + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False): + """Bottleneck block. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super(Bottleneck, self).__init__() + assert style in ['pytorch', 'caffe'] + if style == 'pytorch': + conv1_stride = 1 + conv2_stride = stride + else: + conv1_stride = stride + conv2_stride = 1 + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + def forward(self, x): + + def _inner_forward(x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def make_res_layer(block, + inplanes, + planes, + blocks, + stride=1, + dilation=1, + style='pytorch', + with_cp=False): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + inplanes, + planes, + stride, + dilation, + downsample, + style=style, + with_cp=with_cp)) + inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp)) + + return nn.Sequential(*layers) + + +class ResNet(nn.Module): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + num_stages (int): Resnet stages, normally 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze + running stats (mean and var). + bn_frozen (bool): Whether to freeze weight and bias of BN layers. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + frozen_stages=-1, + bn_eval=True, + bn_frozen=False, + with_cp=False): + super(ResNet, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + assert num_stages >= 1 and num_stages <= 4 + block, stage_blocks = self.arch_settings[depth] + stage_blocks = stage_blocks[:num_stages] + assert len(strides) == len(dilations) == num_stages + assert max(out_indices) < num_stages + + self.out_indices = out_indices + self.style = style + self.frozen_stages = frozen_stages + self.bn_eval = bn_eval + self.bn_frozen = bn_frozen + self.with_cp = with_cp + + self.inplanes = 64 + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.res_layers = [] + for i, num_blocks in enumerate(stage_blocks): + stride = strides[i] + dilation = dilations[i] + planes = 64 * 2**i + res_layer = make_res_layer( + block, + self.inplanes, + planes, + num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + with_cp=with_cp) + self.inplanes = planes * block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + from ..runner import load_checkpoint + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def train(self, mode=True): + super(ResNet, self).train(mode) + if self.bn_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + if self.bn_frozen: + for params in m.parameters(): + params.requires_grad = False + if mode and self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for param in self.bn1.parameters(): + param.requires_grad = False + self.bn1.eval() + self.bn1.weight.requires_grad = False + self.bn1.bias.requires_grad = False + for i in range(1, self.frozen_stages + 1): + mod = getattr(self, f'layer{i}') + mod.eval() + for param in mod.parameters(): + param.requires_grad = False diff --git a/annotator/uniformer/mmcv/cnn/utils/__init__.py b/annotator/uniformer/mmcv/cnn/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .flops_counter import get_model_complexity_info +from .fuse_conv_bn import fuse_conv_bn +from .sync_bn import revert_sync_batchnorm +from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit, + KaimingInit, NormalInit, PretrainedInit, + TruncNormalInit, UniformInit, XavierInit, + bias_init_with_prob, caffe2_xavier_init, + constant_init, initialize, kaiming_init, normal_init, + trunc_normal_init, uniform_init, xavier_init) + +__all__ = [ + 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init', + 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init', + 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize', + 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit', + 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', + 'Caffe2XavierInit', 'revert_sync_batchnorm' +] diff --git a/annotator/uniformer/mmcv/cnn/utils/flops_counter.py b/annotator/uniformer/mmcv/cnn/utils/flops_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..d10af5feca7f4b8c0ba359b7b1c826f754e048be --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/flops_counter.py @@ -0,0 +1,599 @@ +# Modified from flops-counter.pytorch by Vladislav Sovrasov +# original repo: https://github.com/sovrasov/flops-counter.pytorch + +# MIT License + +# Copyright (c) 2018 Vladislav Sovrasov + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import sys +from functools import partial + +import numpy as np +import torch +import torch.nn as nn + +import annotator.uniformer.mmcv as mmcv + + +def get_model_complexity_info(model, + input_shape, + print_per_layer_stat=True, + as_strings=True, + input_constructor=None, + flush=False, + ost=sys.stdout): + """Get complexity information of a model. + + This method can calculate FLOPs and parameter counts of a model with + corresponding input shape. It can also print complexity information for + each layer in a model. + + Supported layers are listed as below: + - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``. + - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``, + ``nn.ReLU6``. + - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``, + ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``, + ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``, + ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``, + ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``. + - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``, + ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``, + ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``. + - Linear: ``nn.Linear``. + - Deconvolution: ``nn.ConvTranspose2d``. + - Upsample: ``nn.Upsample``. + + Args: + model (nn.Module): The model for complexity calculation. + input_shape (tuple): Input shape used for calculation. + print_per_layer_stat (bool): Whether to print complexity information + for each layer in a model. Default: True. + as_strings (bool): Output FLOPs and params counts in a string form. + Default: True. + input_constructor (None | callable): If specified, it takes a callable + method that generates input. otherwise, it will generate a random + tensor with input shape to calculate FLOPs. Default: None. + flush (bool): same as that in :func:`print`. Default: False. + ost (stream): same as ``file`` param in :func:`print`. + Default: sys.stdout. + + Returns: + tuple[float | str]: If ``as_strings`` is set to True, it will return + FLOPs and parameter counts in a string format. otherwise, it will + return those in a float number format. + """ + assert type(input_shape) is tuple + assert len(input_shape) >= 1 + assert isinstance(model, nn.Module) + flops_model = add_flops_counting_methods(model) + flops_model.eval() + flops_model.start_flops_count() + if input_constructor: + input = input_constructor(input_shape) + _ = flops_model(**input) + else: + try: + batch = torch.ones(()).new_empty( + (1, *input_shape), + dtype=next(flops_model.parameters()).dtype, + device=next(flops_model.parameters()).device) + except StopIteration: + # Avoid StopIteration for models which have no parameters, + # like `nn.Relu()`, `nn.AvgPool2d`, etc. + batch = torch.ones(()).new_empty((1, *input_shape)) + + _ = flops_model(batch) + + flops_count, params_count = flops_model.compute_average_flops_cost() + if print_per_layer_stat: + print_model_with_flops( + flops_model, flops_count, params_count, ost=ost, flush=flush) + flops_model.stop_flops_count() + + if as_strings: + return flops_to_string(flops_count), params_to_string(params_count) + + return flops_count, params_count + + +def flops_to_string(flops, units='GFLOPs', precision=2): + """Convert FLOPs number into a string. + + Note that Here we take a multiply-add counts as one FLOP. + + Args: + flops (float): FLOPs number to be converted. + units (str | None): Converted FLOPs units. Options are None, 'GFLOPs', + 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically + choose the most suitable unit for FLOPs. Default: 'GFLOPs'. + precision (int): Digit number after the decimal point. Default: 2. + + Returns: + str: The converted FLOPs number with units. + + Examples: + >>> flops_to_string(1e9) + '1.0 GFLOPs' + >>> flops_to_string(2e5, 'MFLOPs') + '0.2 MFLOPs' + >>> flops_to_string(3e-9, None) + '3e-09 FLOPs' + """ + if units is None: + if flops // 10**9 > 0: + return str(round(flops / 10.**9, precision)) + ' GFLOPs' + elif flops // 10**6 > 0: + return str(round(flops / 10.**6, precision)) + ' MFLOPs' + elif flops // 10**3 > 0: + return str(round(flops / 10.**3, precision)) + ' KFLOPs' + else: + return str(flops) + ' FLOPs' + else: + if units == 'GFLOPs': + return str(round(flops / 10.**9, precision)) + ' ' + units + elif units == 'MFLOPs': + return str(round(flops / 10.**6, precision)) + ' ' + units + elif units == 'KFLOPs': + return str(round(flops / 10.**3, precision)) + ' ' + units + else: + return str(flops) + ' FLOPs' + + +def params_to_string(num_params, units=None, precision=2): + """Convert parameter number into a string. + + Args: + num_params (float): Parameter number to be converted. + units (str | None): Converted FLOPs units. Options are None, 'M', + 'K' and ''. If set to None, it will automatically choose the most + suitable unit for Parameter number. Default: None. + precision (int): Digit number after the decimal point. Default: 2. + + Returns: + str: The converted parameter number with units. + + Examples: + >>> params_to_string(1e9) + '1000.0 M' + >>> params_to_string(2e5) + '200.0 k' + >>> params_to_string(3e-9) + '3e-09' + """ + if units is None: + if num_params // 10**6 > 0: + return str(round(num_params / 10**6, precision)) + ' M' + elif num_params // 10**3: + return str(round(num_params / 10**3, precision)) + ' k' + else: + return str(num_params) + else: + if units == 'M': + return str(round(num_params / 10.**6, precision)) + ' ' + units + elif units == 'K': + return str(round(num_params / 10.**3, precision)) + ' ' + units + else: + return str(num_params) + + +def print_model_with_flops(model, + total_flops, + total_params, + units='GFLOPs', + precision=3, + ost=sys.stdout, + flush=False): + """Print a model with FLOPs for each layer. + + Args: + model (nn.Module): The model to be printed. + total_flops (float): Total FLOPs of the model. + total_params (float): Total parameter counts of the model. + units (str | None): Converted FLOPs units. Default: 'GFLOPs'. + precision (int): Digit number after the decimal point. Default: 3. + ost (stream): same as `file` param in :func:`print`. + Default: sys.stdout. + flush (bool): same as that in :func:`print`. Default: False. + + Example: + >>> class ExampleModel(nn.Module): + + >>> def __init__(self): + >>> super().__init__() + >>> self.conv1 = nn.Conv2d(3, 8, 3) + >>> self.conv2 = nn.Conv2d(8, 256, 3) + >>> self.conv3 = nn.Conv2d(256, 8, 3) + >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + >>> self.flatten = nn.Flatten() + >>> self.fc = nn.Linear(8, 1) + + >>> def forward(self, x): + >>> x = self.conv1(x) + >>> x = self.conv2(x) + >>> x = self.conv3(x) + >>> x = self.avg_pool(x) + >>> x = self.flatten(x) + >>> x = self.fc(x) + >>> return x + + >>> model = ExampleModel() + >>> x = (3, 16, 16) + to print the complexity information state for each layer, you can use + >>> get_model_complexity_info(model, x) + or directly use + >>> print_model_with_flops(model, 4579784.0, 37361) + ExampleModel( + 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, + (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501 + (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1)) + (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1)) + (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1)) + (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, ) + (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True) + ) + """ + + def accumulate_params(self): + if is_supported_instance(self): + return self.__params__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_params() + return sum + + def accumulate_flops(self): + if is_supported_instance(self): + return self.__flops__ / model.__batch_counter__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_flops() + return sum + + def flops_repr(self): + accumulated_num_params = self.accumulate_params() + accumulated_flops_cost = self.accumulate_flops() + return ', '.join([ + params_to_string( + accumulated_num_params, units='M', precision=precision), + '{:.3%} Params'.format(accumulated_num_params / total_params), + flops_to_string( + accumulated_flops_cost, units=units, precision=precision), + '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops), + self.original_extra_repr() + ]) + + def add_extra_repr(m): + m.accumulate_flops = accumulate_flops.__get__(m) + m.accumulate_params = accumulate_params.__get__(m) + flops_extra_repr = flops_repr.__get__(m) + if m.extra_repr != flops_extra_repr: + m.original_extra_repr = m.extra_repr + m.extra_repr = flops_extra_repr + assert m.extra_repr != m.original_extra_repr + + def del_extra_repr(m): + if hasattr(m, 'original_extra_repr'): + m.extra_repr = m.original_extra_repr + del m.original_extra_repr + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + model.apply(add_extra_repr) + print(model, file=ost, flush=flush) + model.apply(del_extra_repr) + + +def get_model_parameters_number(model): + """Calculate parameter number of a model. + + Args: + model (nn.module): The model for parameter number calculation. + + Returns: + float: Parameter number of the model. + """ + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return num_params + + +def add_flops_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + net_main_module.start_flops_count = start_flops_count.__get__( + net_main_module) + net_main_module.stop_flops_count = stop_flops_count.__get__( + net_main_module) + net_main_module.reset_flops_count = reset_flops_count.__get__( + net_main_module) + net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # noqa: E501 + net_main_module) + + net_main_module.reset_flops_count() + + return net_main_module + + +def compute_average_flops_cost(self): + """Compute average FLOPs cost. + + A method to compute average FLOPs cost, which will be available after + `add_flops_counting_methods()` is called on a desired net object. + + Returns: + float: Current mean flops consumption per image. + """ + batches_count = self.__batch_counter__ + flops_sum = 0 + for module in self.modules(): + if is_supported_instance(module): + flops_sum += module.__flops__ + params_sum = get_model_parameters_number(self) + return flops_sum / batches_count, params_sum + + +def start_flops_count(self): + """Activate the computation of mean flops consumption per image. + + A method to activate the computation of mean flops consumption per image. + which will be available after ``add_flops_counting_methods()`` is called on + a desired net object. It should be called before running the network. + """ + add_batch_counter_hook_function(self) + + def add_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + return + + else: + handle = module.register_forward_hook( + get_modules_mapping()[type(module)]) + + module.__flops_handle__ = handle + + self.apply(partial(add_flops_counter_hook_function)) + + +def stop_flops_count(self): + """Stop computing the mean flops consumption per image. + + A method to stop computing the mean flops consumption per image, which will + be available after ``add_flops_counting_methods()`` is called on a desired + net object. It can be called to pause the computation whenever. + """ + remove_batch_counter_hook_function(self) + self.apply(remove_flops_counter_hook_function) + + +def reset_flops_count(self): + """Reset statistics computed so far. + + A method to Reset computed statistics, which will be available after + `add_flops_counting_methods()` is called on a desired net object. + """ + add_batch_counter_variables_or_reset(self) + self.apply(add_flops_counter_variable_or_reset) + + +# ---- Internal functions +def empty_flops_counter_hook(module, input, output): + module.__flops__ += 0 + + +def upsample_flops_counter_hook(module, input, output): + output_size = output[0] + batch_size = output_size.shape[0] + output_elements_count = batch_size + for val in output_size.shape[1:]: + output_elements_count *= val + module.__flops__ += int(output_elements_count) + + +def relu_flops_counter_hook(module, input, output): + active_elements_count = output.numel() + module.__flops__ += int(active_elements_count) + + +def linear_flops_counter_hook(module, input, output): + input = input[0] + output_last_dim = output.shape[ + -1] # pytorch checks dimensions, so here we don't care much + module.__flops__ += int(np.prod(input.shape) * output_last_dim) + + +def pool_flops_counter_hook(module, input, output): + input = input[0] + module.__flops__ += int(np.prod(input.shape)) + + +def norm_flops_counter_hook(module, input, output): + input = input[0] + + batch_flops = np.prod(input.shape) + if (getattr(module, 'affine', False) + or getattr(module, 'elementwise_affine', False)): + batch_flops *= 2 + module.__flops__ += int(batch_flops) + + +def deconv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + + batch_size = input.shape[0] + input_height, input_width = input.shape[2:] + + kernel_height, kernel_width = conv_module.kernel_size + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = ( + kernel_height * kernel_width * in_channels * filters_per_channel) + + active_elements_count = batch_size * input_height * input_width + overall_conv_flops = conv_per_position_flops * active_elements_count + bias_flops = 0 + if conv_module.bias is not None: + output_height, output_width = output.shape[2:] + bias_flops = out_channels * batch_size * output_height * output_height + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def conv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + + batch_size = input.shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(conv_module.kernel_size) + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = int( + np.prod(kernel_dims)) * in_channels * filters_per_channel + + active_elements_count = batch_size * int(np.prod(output_dims)) + + overall_conv_flops = conv_per_position_flops * active_elements_count + + bias_flops = 0 + + if conv_module.bias is not None: + + bias_flops = out_channels * active_elements_count + + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def batch_counter_hook(module, input, output): + batch_size = 1 + if len(input) > 0: + # Can have multiple inputs, getting the first one + input = input[0] + batch_size = len(input) + else: + pass + print('Warning! No positional inputs found for a module, ' + 'assuming batch size is 1.') + module.__batch_counter__ += batch_size + + +def add_batch_counter_variables_or_reset(module): + + module.__batch_counter__ = 0 + + +def add_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + return + + handle = module.register_forward_hook(batch_counter_hook) + module.__batch_counter_handle__ = handle + + +def remove_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + module.__batch_counter_handle__.remove() + del module.__batch_counter_handle__ + + +def add_flops_counter_variable_or_reset(module): + if is_supported_instance(module): + if hasattr(module, '__flops__') or hasattr(module, '__params__'): + print('Warning: variables __flops__ or __params__ are already ' + 'defined for the module' + type(module).__name__ + + ' ptflops can affect your code!') + module.__flops__ = 0 + module.__params__ = get_model_parameters_number(module) + + +def is_supported_instance(module): + if type(module) in get_modules_mapping(): + return True + return False + + +def remove_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + module.__flops_handle__.remove() + del module.__flops_handle__ + + +def get_modules_mapping(): + return { + # convolutions + nn.Conv1d: conv_flops_counter_hook, + nn.Conv2d: conv_flops_counter_hook, + mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook, + nn.Conv3d: conv_flops_counter_hook, + mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook, + # activations + nn.ReLU: relu_flops_counter_hook, + nn.PReLU: relu_flops_counter_hook, + nn.ELU: relu_flops_counter_hook, + nn.LeakyReLU: relu_flops_counter_hook, + nn.ReLU6: relu_flops_counter_hook, + # poolings + nn.MaxPool1d: pool_flops_counter_hook, + nn.AvgPool1d: pool_flops_counter_hook, + nn.AvgPool2d: pool_flops_counter_hook, + nn.MaxPool2d: pool_flops_counter_hook, + mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook, + nn.MaxPool3d: pool_flops_counter_hook, + mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook, + nn.AvgPool3d: pool_flops_counter_hook, + nn.AdaptiveMaxPool1d: pool_flops_counter_hook, + nn.AdaptiveAvgPool1d: pool_flops_counter_hook, + nn.AdaptiveMaxPool2d: pool_flops_counter_hook, + nn.AdaptiveAvgPool2d: pool_flops_counter_hook, + nn.AdaptiveMaxPool3d: pool_flops_counter_hook, + nn.AdaptiveAvgPool3d: pool_flops_counter_hook, + # normalizations + nn.BatchNorm1d: norm_flops_counter_hook, + nn.BatchNorm2d: norm_flops_counter_hook, + nn.BatchNorm3d: norm_flops_counter_hook, + nn.GroupNorm: norm_flops_counter_hook, + nn.InstanceNorm1d: norm_flops_counter_hook, + nn.InstanceNorm2d: norm_flops_counter_hook, + nn.InstanceNorm3d: norm_flops_counter_hook, + nn.LayerNorm: norm_flops_counter_hook, + # FC + nn.Linear: linear_flops_counter_hook, + mmcv.cnn.bricks.Linear: linear_flops_counter_hook, + # Upscale + nn.Upsample: upsample_flops_counter_hook, + # Deconvolution + nn.ConvTranspose2d: deconv_flops_counter_hook, + mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook, + } diff --git a/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py b/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7076f80bf37f7931185bf0293ffcc1ce19c8ef --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + + +def _fuse_conv_bn(conv, bn): + """Fuse conv and bn into one module. + + Args: + conv (nn.Module): Conv to be fused. + bn (nn.Module): BN to be fused. + + Returns: + nn.Module: Fused module. + """ + conv_w = conv.weight + conv_b = conv.bias if conv.bias is not None else torch.zeros_like( + bn.running_mean) + + factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) + conv.weight = nn.Parameter(conv_w * + factor.reshape([conv.out_channels, 1, 1, 1])) + conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) + return conv + + +def fuse_conv_bn(module): + """Recursively fuse conv and bn in a module. + + During inference, the functionary of batch norm layers is turned off + but only the mean and var alone channels are used, which exposes the + chance to fuse it with the preceding conv layers to save computations and + simplify network structures. + + Args: + module (nn.Module): Module to be fused. + + Returns: + nn.Module: Fused module. + """ + last_conv = None + last_conv_name = None + + for name, child in module.named_children(): + if isinstance(child, + (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)): + if last_conv is None: # only fuse BN that is after Conv + continue + fused_conv = _fuse_conv_bn(last_conv, child) + module._modules[last_conv_name] = fused_conv + # To reduce changes, set BN as Identity instead of deleting it. + module._modules[name] = nn.Identity() + last_conv = None + elif isinstance(child, nn.Conv2d): + last_conv = child + last_conv_name = name + else: + fuse_conv_bn(child) + return module diff --git a/annotator/uniformer/mmcv/cnn/utils/sync_bn.py b/annotator/uniformer/mmcv/cnn/utils/sync_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..f78f39181d75bb85c53e8c7c8eaf45690e9f0bee --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/sync_bn.py @@ -0,0 +1,59 @@ +import torch + +import annotator.uniformer.mmcv as mmcv + + +class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm): + """A general BatchNorm layer without input dimension check. + + Reproduced from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc + is `_check_input_dim` that is designed for tensor sanity checks. + The check has been bypassed in this class for the convenience of converting + SyncBatchNorm. + """ + + def _check_input_dim(self, input): + return + + +def revert_sync_batchnorm(module): + """Helper function to convert all `SyncBatchNorm` (SyncBN) and + `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to + `BatchNormXd` layers. + + Adapted from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + + Args: + module (nn.Module): The module containing `SyncBatchNorm` layers. + + Returns: + module_output: The converted module with `BatchNormXd` layers. + """ + module_output = module + module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm] + if hasattr(mmcv, 'ops'): + module_checklist.append(mmcv.ops.SyncBatchNorm) + if isinstance(module, tuple(module_checklist)): + module_output = _BatchNormXd(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + # no_grad() may not be needed here but + # just to be consistent with `convert_sync_batchnorm()` + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training + # qconfig exists in quantized models + if hasattr(module, 'qconfig'): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module(name, revert_sync_batchnorm(child)) + del module + return module_output diff --git a/annotator/uniformer/mmcv/cnn/utils/weight_init.py b/annotator/uniformer/mmcv/cnn/utils/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..287a1d0bffe26e023029d48634d9b761deda7ba4 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/weight_init.py @@ -0,0 +1,684 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +import warnings + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + +from annotator.uniformer.mmcv.utils import Registry, build_from_cfg, get_logger, print_log + +INITIALIZERS = Registry('initializer') + + +def update_init_info(module, init_info): + """Update the `_params_init_info` in the module if the value of parameters + are changed. + + Args: + module (obj:`nn.Module`): The module of PyTorch with a user-defined + attribute `_params_init_info` which records the initialization + information. + init_info (str): The string that describes the initialization. + """ + assert hasattr( + module, + '_params_init_info'), f'Can not find `_params_init_info` in {module}' + for name, param in module.named_parameters(): + + assert param in module._params_init_info, ( + f'Find a new :obj:`Parameter` ' + f'named `{name}` during executing the ' + f'`init_weights` of ' + f'`{module.__class__.__name__}`. ' + f'Please do not add or ' + f'replace parameters during executing ' + f'the `init_weights`. ') + + # The parameter has been changed during executing the + # `init_weights` of module + mean_value = param.data.mean() + if module._params_init_info[param]['tmp_mean_value'] != mean_value: + module._params_init_info[param]['init_info'] = init_info + module._params_init_info[param]['tmp_mean_value'] = mean_value + + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def xavier_init(module, gain=1, bias=0, distribution='normal'): + assert distribution in ['uniform', 'normal'] + if hasattr(module, 'weight') and module.weight is not None: + if distribution == 'uniform': + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def normal_init(module, mean=0, std=1, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def trunc_normal_init(module: nn.Module, + mean: float = 0, + std: float = 1, + a: float = -2, + b: float = 2, + bias: float = 0) -> None: + if hasattr(module, 'weight') and module.weight is not None: + trunc_normal_(module.weight, mean, std, a, b) # type: ignore + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) # type: ignore + + +def uniform_init(module, a=0, b=1, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.uniform_(module.weight, a, b) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def kaiming_init(module, + a=0, + mode='fan_out', + nonlinearity='relu', + bias=0, + distribution='normal'): + assert distribution in ['uniform', 'normal'] + if hasattr(module, 'weight') and module.weight is not None: + if distribution == 'uniform': + nn.init.kaiming_uniform_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def caffe2_xavier_init(module, bias=0): + # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch + # Acknowledgment to FAIR's internal code + kaiming_init( + module, + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + bias=bias, + distribution='uniform') + + +def bias_init_with_prob(prior_prob): + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-np.log((1 - prior_prob) / prior_prob)) + return bias_init + + +def _get_bases_name(m): + return [b.__name__ for b in m.__class__.__bases__] + + +class BaseInit(object): + + def __init__(self, *, bias=0, bias_prob=None, layer=None): + self.wholemodule = False + if not isinstance(bias, (int, float)): + raise TypeError(f'bias must be a number, but got a {type(bias)}') + + if bias_prob is not None: + if not isinstance(bias_prob, float): + raise TypeError(f'bias_prob type must be float, \ + but got {type(bias_prob)}') + + if layer is not None: + if not isinstance(layer, (str, list)): + raise TypeError(f'layer must be a str or a list of str, \ + but got a {type(layer)}') + else: + layer = [] + + if bias_prob is not None: + self.bias = bias_init_with_prob(bias_prob) + else: + self.bias = bias + self.layer = [layer] if isinstance(layer, str) else layer + + def _get_init_info(self): + info = f'{self.__class__.__name__}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Constant') +class ConstantInit(BaseInit): + """Initialize module parameters with constant values. + + Args: + val (int | float): the value to fill the weights in the module with + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, val, **kwargs): + super().__init__(**kwargs) + self.val = val + + def __call__(self, module): + + def init(m): + if self.wholemodule: + constant_init(m, self.val, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + constant_init(m, self.val, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Xavier') +class XavierInit(BaseInit): + r"""Initialize module parameters with values according to the method + described in `Understanding the difficulty of training deep feedforward + neural networks - Glorot, X. & Bengio, Y. (2010). + `_ + + Args: + gain (int | float): an optional scaling factor. Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + distribution (str): distribution either be ``'normal'`` + or ``'uniform'``. Defaults to ``'normal'``. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, gain=1, distribution='normal', **kwargs): + super().__init__(**kwargs) + self.gain = gain + self.distribution = distribution + + def __call__(self, module): + + def init(m): + if self.wholemodule: + xavier_init(m, self.gain, self.bias, self.distribution) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + xavier_init(m, self.gain, self.bias, self.distribution) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: gain={self.gain}, ' \ + f'distribution={self.distribution}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Normal') +class NormalInit(BaseInit): + r"""Initialize module parameters with the values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + + Args: + mean (int | float):the mean of the normal distribution. Defaults to 0. + std (int | float): the standard deviation of the normal distribution. + Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + + """ + + def __init__(self, mean=0, std=1, **kwargs): + super().__init__(**kwargs) + self.mean = mean + self.std = std + + def __call__(self, module): + + def init(m): + if self.wholemodule: + normal_init(m, self.mean, self.std, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + normal_init(m, self.mean, self.std, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: mean={self.mean},' \ + f' std={self.std}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='TruncNormal') +class TruncNormalInit(BaseInit): + r"""Initialize module parameters with the values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values + outside :math:`[a, b]`. + + Args: + mean (float): the mean of the normal distribution. Defaults to 0. + std (float): the standard deviation of the normal distribution. + Defaults to 1. + a (float): The minimum cutoff value. + b ( float): The maximum cutoff value. + bias (float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + + """ + + def __init__(self, + mean: float = 0, + std: float = 1, + a: float = -2, + b: float = 2, + **kwargs) -> None: + super().__init__(**kwargs) + self.mean = mean + self.std = std + self.a = a + self.b = b + + def __call__(self, module: nn.Module) -> None: + + def init(m): + if self.wholemodule: + trunc_normal_init(m, self.mean, self.std, self.a, self.b, + self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + trunc_normal_init(m, self.mean, self.std, self.a, self.b, + self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \ + f' mean={self.mean}, std={self.std}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Uniform') +class UniformInit(BaseInit): + r"""Initialize module parameters with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + + Args: + a (int | float): the lower bound of the uniform distribution. + Defaults to 0. + b (int | float): the upper bound of the uniform distribution. + Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, a=0, b=1, **kwargs): + super().__init__(**kwargs) + self.a = a + self.b = b + + def __call__(self, module): + + def init(m): + if self.wholemodule: + uniform_init(m, self.a, self.b, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + uniform_init(m, self.a, self.b, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: a={self.a},' \ + f' b={self.b}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Kaiming') +class KaimingInit(BaseInit): + r"""Initialize module parameters with the values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification - He, K. et al. (2015). + `_ + + Args: + a (int | float): the negative slope of the rectifier used after this + layer (only used with ``'leaky_relu'``). Defaults to 0. + mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing + ``'fan_in'`` preserves the magnitude of the variance of the weights + in the forward pass. Choosing ``'fan_out'`` preserves the + magnitudes in the backwards pass. Defaults to ``'fan_out'``. + nonlinearity (str): the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` . + Defaults to 'relu'. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + distribution (str): distribution either be ``'normal'`` or + ``'uniform'``. Defaults to ``'normal'``. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, + a=0, + mode='fan_out', + nonlinearity='relu', + distribution='normal', + **kwargs): + super().__init__(**kwargs) + self.a = a + self.mode = mode + self.nonlinearity = nonlinearity + self.distribution = distribution + + def __call__(self, module): + + def init(m): + if self.wholemodule: + kaiming_init(m, self.a, self.mode, self.nonlinearity, + self.bias, self.distribution) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + kaiming_init(m, self.a, self.mode, self.nonlinearity, + self.bias, self.distribution) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \ + f'nonlinearity={self.nonlinearity}, ' \ + f'distribution ={self.distribution}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Caffe2Xavier') +class Caffe2XavierInit(KaimingInit): + # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch + # Acknowledgment to FAIR's internal code + def __init__(self, **kwargs): + super().__init__( + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + distribution='uniform', + **kwargs) + + def __call__(self, module): + super().__call__(module) + + +@INITIALIZERS.register_module(name='Pretrained') +class PretrainedInit(object): + """Initialize module by loading a pretrained model. + + Args: + checkpoint (str): the checkpoint file of the pretrained model should + be load. + prefix (str, optional): the prefix of a sub-module in the pretrained + model. it is for loading a part of the pretrained model to + initialize. For example, if we would like to only load the + backbone of a detector model, we can set ``prefix='backbone.'``. + Defaults to None. + map_location (str): map tensors into proper locations. + """ + + def __init__(self, checkpoint, prefix=None, map_location=None): + self.checkpoint = checkpoint + self.prefix = prefix + self.map_location = map_location + + def __call__(self, module): + from annotator.uniformer.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint, + load_state_dict) + logger = get_logger('mmcv') + if self.prefix is None: + print_log(f'load model from: {self.checkpoint}', logger=logger) + load_checkpoint( + module, + self.checkpoint, + map_location=self.map_location, + strict=False, + logger=logger) + else: + print_log( + f'load {self.prefix} in model from: {self.checkpoint}', + logger=logger) + state_dict = _load_checkpoint_with_prefix( + self.prefix, self.checkpoint, map_location=self.map_location) + load_state_dict(module, state_dict, strict=False, logger=logger) + + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: load from {self.checkpoint}' + return info + + +def _initialize(module, cfg, wholemodule=False): + func = build_from_cfg(cfg, INITIALIZERS) + # wholemodule flag is for override mode, there is no layer key in override + # and initializer will give init values for the whole module with the name + # in override. + func.wholemodule = wholemodule + func(module) + + +def _initialize_override(module, override, cfg): + if not isinstance(override, (dict, list)): + raise TypeError(f'override must be a dict or a list of dict, \ + but got {type(override)}') + + override = [override] if isinstance(override, dict) else override + + for override_ in override: + + cp_override = copy.deepcopy(override_) + name = cp_override.pop('name', None) + if name is None: + raise ValueError('`override` must contain the key "name",' + f'but got {cp_override}') + # if override only has name key, it means use args in init_cfg + if not cp_override: + cp_override.update(cfg) + # if override has name key and other args except type key, it will + # raise error + elif 'type' not in cp_override.keys(): + raise ValueError( + f'`override` need "type" key, but got {cp_override}') + + if hasattr(module, name): + _initialize(getattr(module, name), cp_override, wholemodule=True) + else: + raise RuntimeError(f'module did not have attribute {name}, ' + f'but init_cfg is {cp_override}.') + + +def initialize(module, init_cfg): + """Initialize a module. + + Args: + module (``torch.nn.Module``): the module will be initialized. + init_cfg (dict | list[dict]): initialization configuration dict to + define initializer. OpenMMLab has implemented 6 initializers + including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``, + ``Kaiming``, and ``Pretrained``. + Example: + >>> module = nn.Linear(2, 3, bias=True) + >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2) + >>> initialize(module, init_cfg) + + >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2)) + >>> # define key ``'layer'`` for initializing layer with different + >>> # configuration + >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1), + dict(type='Constant', layer='Linear', val=2)] + >>> initialize(module, init_cfg) + + >>> # define key``'override'`` to initialize some specific part in + >>> # module + >>> class FooNet(nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.feat = nn.Conv2d(3, 16, 3) + >>> self.reg = nn.Conv2d(16, 10, 3) + >>> self.cls = nn.Conv2d(16, 5, 3) + >>> model = FooNet() + >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d', + >>> override=dict(type='Constant', name='reg', val=3, bias=4)) + >>> initialize(model, init_cfg) + + >>> model = ResNet(depth=50) + >>> # Initialize weights with the pretrained model. + >>> init_cfg = dict(type='Pretrained', + checkpoint='torchvision://resnet50') + >>> initialize(model, init_cfg) + + >>> # Initialize weights of a sub-module with the specific part of + >>> # a pretrained model by using "prefix". + >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\ + >>> 'retinanet_r50_fpn_1x_coco/'\ + >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' + >>> init_cfg = dict(type='Pretrained', + checkpoint=url, prefix='backbone.') + """ + if not isinstance(init_cfg, (dict, list)): + raise TypeError(f'init_cfg must be a dict or a list of dict, \ + but got {type(init_cfg)}') + + if isinstance(init_cfg, dict): + init_cfg = [init_cfg] + + for cfg in init_cfg: + # should deeply copy the original config because cfg may be used by + # other modules, e.g., one init_cfg shared by multiple bottleneck + # blocks, the expected cfg will be changed after pop and will change + # the initialization behavior of other modules + cp_cfg = copy.deepcopy(cfg) + override = cp_cfg.pop('override', None) + _initialize(module, cp_cfg) + + if override is not None: + cp_cfg.pop('layer', None) + _initialize_override(module, override, cp_cfg) + else: + # All attributes in module have same initialization. + pass + + +def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float, + b: float) -> Tensor: + # Method based on + # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + # Modified from + # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [lower, upper], then translate + # to [2lower-1, 2upper-1]. + tensor.uniform_(2 * lower - 1, 2 * upper - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor: Tensor, + mean: float = 0., + std: float = 1., + a: float = -2., + b: float = 2.) -> Tensor: + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Modified from + https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py + + Args: + tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`. + mean (float): the mean of the normal distribution. + std (float): the standard deviation of the normal distribution. + a (float): the minimum cutoff value. + b (float): the maximum cutoff value. + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/annotator/uniformer/mmcv/cnn/vgg.py b/annotator/uniformer/mmcv/cnn/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..8778b649561a45a9652b1a15a26c2d171e58f3e1 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/vgg.py @@ -0,0 +1,175 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.nn as nn + +from .utils import constant_init, kaiming_init, normal_init + + +def conv3x3(in_planes, out_planes, dilation=1): + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + padding=dilation, + dilation=dilation) + + +def make_vgg_layer(inplanes, + planes, + num_blocks, + dilation=1, + with_bn=False, + ceil_mode=False): + layers = [] + for _ in range(num_blocks): + layers.append(conv3x3(inplanes, planes, dilation)) + if with_bn: + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + inplanes = planes + layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) + + return layers + + +class VGG(nn.Module): + """VGG backbone. + + Args: + depth (int): Depth of vgg, from {11, 13, 16, 19}. + with_bn (bool): Use BatchNorm or not. + num_classes (int): number of classes for classification. + num_stages (int): VGG stages, normally 5. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze + running stats (mean and var). + bn_frozen (bool): Whether to freeze weight and bias of BN layers. + """ + + arch_settings = { + 11: (1, 1, 2, 2, 2), + 13: (2, 2, 2, 2, 2), + 16: (2, 2, 3, 3, 3), + 19: (2, 2, 4, 4, 4) + } + + def __init__(self, + depth, + with_bn=False, + num_classes=-1, + num_stages=5, + dilations=(1, 1, 1, 1, 1), + out_indices=(0, 1, 2, 3, 4), + frozen_stages=-1, + bn_eval=True, + bn_frozen=False, + ceil_mode=False, + with_last_pool=True): + super(VGG, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for vgg') + assert num_stages >= 1 and num_stages <= 5 + stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + assert len(dilations) == num_stages + assert max(out_indices) <= num_stages + + self.num_classes = num_classes + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.bn_eval = bn_eval + self.bn_frozen = bn_frozen + + self.inplanes = 3 + start_idx = 0 + vgg_layers = [] + self.range_sub_modules = [] + for i, num_blocks in enumerate(self.stage_blocks): + num_modules = num_blocks * (2 + with_bn) + 1 + end_idx = start_idx + num_modules + dilation = dilations[i] + planes = 64 * 2**i if i < 4 else 512 + vgg_layer = make_vgg_layer( + self.inplanes, + planes, + num_blocks, + dilation=dilation, + with_bn=with_bn, + ceil_mode=ceil_mode) + vgg_layers.extend(vgg_layer) + self.inplanes = planes + self.range_sub_modules.append([start_idx, end_idx]) + start_idx = end_idx + if not with_last_pool: + vgg_layers.pop(-1) + self.range_sub_modules[-1][1] -= 1 + self.module_name = 'features' + self.add_module(self.module_name, nn.Sequential(*vgg_layers)) + + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + from ..runner import load_checkpoint + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + elif isinstance(m, nn.Linear): + normal_init(m, std=0.01) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + outs = [] + vgg_layers = getattr(self, self.module_name) + for i in range(len(self.stage_blocks)): + for j in range(*self.range_sub_modules[i]): + vgg_layer = vgg_layers[j] + x = vgg_layer(x) + if i in self.out_indices: + outs.append(x) + if self.num_classes > 0: + x = x.view(x.size(0), -1) + x = self.classifier(x) + outs.append(x) + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def train(self, mode=True): + super(VGG, self).train(mode) + if self.bn_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + if self.bn_frozen: + for params in m.parameters(): + params.requires_grad = False + vgg_layers = getattr(self, self.module_name) + if mode and self.frozen_stages >= 0: + for i in range(self.frozen_stages): + for j in range(*self.range_sub_modules[i]): + mod = vgg_layers[j] + mod.eval() + for param in mod.parameters(): + param.requires_grad = False diff --git a/annotator/uniformer/mmcv/engine/__init__.py b/annotator/uniformer/mmcv/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082 --- /dev/null +++ b/annotator/uniformer/mmcv/engine/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test, + single_gpu_test) + +__all__ = [ + 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', + 'single_gpu_test' +] diff --git a/annotator/uniformer/mmcv/engine/test.py b/annotator/uniformer/mmcv/engine/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8dbeef271db634ec2dadfda3bc0b5ef9c7a677ff --- /dev/null +++ b/annotator/uniformer/mmcv/engine/test.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import pickle +import shutil +import tempfile +import time + +import torch +import torch.distributed as dist + +import annotator.uniformer.mmcv as mmcv +from annotator.uniformer.mmcv.runner import get_dist_info + + +def single_gpu_test(model, data_loader): + """Test model with a single gpu. + + This method tests model with a single gpu and displays test progress bar. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for data in data_loader: + with torch.no_grad(): + result = model(return_loss=False, **data) + results.extend(result) + + # Assume result has the same length of batch_size + # refer to https://github.com/open-mmlab/mmcv/issues/985 + batch_size = len(result) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): + """Test model with multiple gpus. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting + ``gpu_collect=True``, it encodes results to gpu tensors and use gpu + communication for results collection. On cpu mode it saves the results on + different gpus to ``tmpdir`` and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = mmcv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + results.extend(result) + + if rank == 0: + batch_size = len(result) + batch_size_all = batch_size * world_size + if batch_size_all + prog_bar.completed > len(dataset): + batch_size_all = len(dataset) - prog_bar.completed + for _ in range(batch_size_all): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results under cpu mode. + + On cpu mode, this function will save the results on different gpus to + ``tmpdir`` and collect them by the rank 0 worker. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. + tmpdir (str | None): temporal directory for collected results to + store. If set to None, it will create a random temporal directory + for it. + + Returns: + list: The collected results. + """ + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + mmcv.mkdir_or_exist('.dist_test') + tmpdir = tempfile.mkdtemp(dir='.dist_test') + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + mmcv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, f'part_{i}.pkl') + part_result = mmcv.load(part_file) + # When data is severely insufficient, an empty part_result + # on a certain gpu could makes the overall outputs empty. + if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results + + +def collect_results_gpu(result_part, size): + """Collect results under gpu mode. + + On gpu mode, this function will encode results to gpu tensors and use gpu + communication for results collection. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. + + Returns: + list: The collected results. + """ + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) + # When data is severely insufficient, an empty part_result + # on a certain gpu could makes the overall outputs empty. + if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results diff --git a/annotator/uniformer/mmcv/fileio/__init__.py b/annotator/uniformer/mmcv/fileio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .file_client import BaseStorageBackend, FileClient +from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler +from .io import dump, load, register_handler +from .parse import dict_from_file, list_from_file + +__all__ = [ + 'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler', + 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', + 'list_from_file', 'dict_from_file' +] diff --git a/annotator/uniformer/mmcv/fileio/file_client.py b/annotator/uniformer/mmcv/fileio/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..950f0c1aeab14b8e308a7455ccd64a95b5d98add --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/file_client.py @@ -0,0 +1,1148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import os +import os.path as osp +import re +import tempfile +import warnings +from abc import ABCMeta, abstractmethod +from contextlib import contextmanager +from pathlib import Path +from typing import Iterable, Iterator, Optional, Tuple, Union +from urllib.request import urlopen + +import annotator.uniformer.mmcv as mmcv +from annotator.uniformer.mmcv.utils.misc import has_method +from annotator.uniformer.mmcv.utils.path import is_filepath + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + # a flag to indicate whether the backend can create a symlink for a file + _allow_symlink = False + + @property + def name(self): + return self.__class__.__name__ + + @property + def allow_symlink(self): + return self._allow_symlink + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class CephBackend(BaseStorageBackend): + """Ceph storage backend (for internal use). + + Args: + path_mapping (dict|None): path mapping dict from local path to Petrel + path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath`` + will be replaced by ``dst``. Default: None. + + .. warning:: + :class:`mmcv.fileio.file_client.CephBackend` will be deprecated, + please use :class:`mmcv.fileio.file_client.PetrelBackend` instead. + """ + + def __init__(self, path_mapping=None): + try: + import ceph + except ImportError: + raise ImportError('Please install ceph to enable CephBackend.') + + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead') + self._client = ceph.S3Client() + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + def get(self, filepath): + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + value = self._client.Get(filepath) + value_buf = memoryview(value) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + +class PetrelBackend(BaseStorageBackend): + """Petrel storage backend (for internal use). + + PetrelBackend supports reading and writing data to multiple clusters. + If the file path contains the cluster name, PetrelBackend will read data + from specified cluster or write data to it. Otherwise, PetrelBackend will + access the default cluster. + + Args: + path_mapping (dict, optional): Path mapping dict from local path to + Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in + ``filepath`` will be replaced by ``dst``. Default: None. + enable_mc (bool, optional): Whether to enable memcached support. + Default: True. + + Examples: + >>> filepath1 = 's3://path/of/file' + >>> filepath2 = 'cluster-name:s3://path/of/file' + >>> client = PetrelBackend() + >>> client.get(filepath1) # get data from default cluster + >>> client.get(filepath2) # get data from 'cluster-name' cluster + """ + + def __init__(self, + path_mapping: Optional[dict] = None, + enable_mc: bool = True): + try: + from petrel_client import client + except ImportError: + raise ImportError('Please install petrel_client to enable ' + 'PetrelBackend.') + + self._client = client.Client(enable_mc=enable_mc) + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + def _map_path(self, filepath: Union[str, Path]) -> str: + """Map ``filepath`` to a string path whose prefix will be replaced by + :attr:`self.path_mapping`. + + Args: + filepath (str): Path to be mapped. + """ + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + return filepath + + def _format_path(self, filepath: str) -> str: + """Convert a ``filepath`` to standard format of petrel oss. + + If the ``filepath`` is concatenated by ``os.path.join``, in a Windows + environment, the ``filepath`` will be the format of + 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the + above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. + + Args: + filepath (str): Path to be formatted. + """ + return re.sub(r'\\+', '/', filepath) + + def get(self, filepath: Union[str, Path]) -> memoryview: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + memoryview: A memory view of expected bytes object to avoid + copying. The memoryview object can be converted to bytes by + ``value_buf.tobytes()``. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + value = self._client.Get(filepath) + value_buf = memoryview(value) + return value_buf + + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return str(self.get(filepath), encoding=encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Save data to a given ``filepath``. + + Args: + obj (bytes): Data to be saved. + filepath (str or Path): Path to write data. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + self._client.put(filepath, obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Save data to a given ``filepath``. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to encode the ``obj``. + Default: 'utf-8'. + """ + self.put(bytes(obj, encoding=encoding), filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ + if not has_method(self._client, 'delete'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `delete` method, please use a higher version or dev' + ' branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + self._client.delete(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + if not (has_method(self._client, 'contains') + and has_method(self._client, 'isdir')): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `contains` and `isdir` methods, please use a higher' + 'version or dev branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.contains(filepath) or self._client.isdir(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + if not has_method(self._client, 'isdir'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `isdir` method, please use a higher version or dev' + ' branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + if not has_method(self._client, 'contains'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `contains` method, please use a higher version or ' + 'dev branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.contains(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result after concatenation. + """ + filepath = self._format_path(self._map_path(filepath)) + if filepath.endswith('/'): + filepath = filepath[:-1] + formatted_paths = [filepath] + for path in filepaths: + formatted_paths.append(self._format_path(self._map_path(path))) + return '/'.join(formatted_paths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: + """Download a file from ``filepath`` and return a temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str | Path): Download a file from ``filepath``. + + Examples: + >>> client = PetrelBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('s3://path/of/your/file') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one temporary path. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + assert self.isfile(filepath) + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + Petrel has no concept of directories but it simulates the directory + hierarchy in the filesystem through public prefixes. In addition, + if the returned path ends with '/', it means the path is a public + prefix which is a logical directory. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + In addition, the returned path of directory will not contains the + suffix '/' which is consistent with other backends. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + if not has_method(self._client, 'list'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `list` method, please use a higher version or dev' + ' branch instead.')) + + dir_path = self._map_path(dir_path) + dir_path = self._format_path(dir_path) + if list_dir and suffix is not None: + raise TypeError( + '`list_dir` should be False when `suffix` is not None') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + # Petrel's simulated directory hierarchy assumes that directory paths + # should end with `/` + if not dir_path.endswith('/'): + dir_path += '/' + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for path in self._client.list(dir_path): + # the `self.isdir` is not used here to determine whether path + # is a directory, because `self.isdir` relies on + # `self._client.list` + if path.endswith('/'): # a directory path + next_dir_path = self.join_path(dir_path, path) + if list_dir: + # get the relative path and exclude the last + # character '/' + rel_dir = next_dir_path[len(root):-1] + yield rel_dir + if recursive: + yield from _list_dir_or_file(next_dir_path, list_dir, + list_file, suffix, + recursive) + else: # a file path + absolute_path = self.join_path(dir_path, path) + rel_path = absolute_path[len(root):] + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_path (str): Lmdb database path. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_path (str): Lmdb database path. + """ + + def __init__(self, + db_path, + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + self.db_path = str(db_path) + self._client = lmdb.open( + self.db_path, + readonly=readonly, + lock=lock, + readahead=readahead, + **kwargs) + + def get(self, filepath): + """Get values according to the filepath. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + """ + filepath = str(filepath) + with self._client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + _allow_symlink = True + + def get(self, filepath: Union[str, Path]) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + with open(filepath, 'r', encoding=encoding) as f: + value_buf = f.read() + return value_buf + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + mmcv.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'wb') as f: + f.write(obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + mmcv.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ + os.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return osp.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return osp.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return osp.join(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]: + """Only for unified API and do nothing.""" + yield filepath + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + if list_dir and suffix is not None: + raise TypeError('`suffix` should be None when `list_dir` is True') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, + list_file, suffix, + recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + +class HTTPBackend(BaseStorageBackend): + """HTTP and HTTPS storage bachend.""" + + def get(self, filepath): + value_buf = urlopen(filepath).read() + return value_buf + + def get_text(self, filepath, encoding='utf-8'): + value_buf = urlopen(filepath).read() + return value_buf.decode(encoding) + + @contextmanager + def get_local_path(self, filepath: str) -> Iterable[str]: + """Download a file from ``filepath``. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> client = HTTPBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('http://path/of/your/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + +class FileClient: + """A general file client to access files in different backends. + + The client loads a file or text in a specified backend from its path + and returns it as a binary or text file. There are two ways to choose a + backend, the name of backend and the prefix of path. Although both of them + can be used to choose a storage backend, ``backend`` has a higher priority + that is if they are all set, the storage backend will be chosen by the + backend argument. If they are all `None`, the disk backend will be chosen. + Note that It can also register other backend accessor with a given name, + prefixes, and backend class. In addition, We use the singleton pattern to + avoid repeated object creation. If the arguments are the same, the same + object will be returned. + + Args: + backend (str, optional): The storage backend type. Options are "disk", + "ceph", "memcached", "lmdb", "http" and "petrel". Default: None. + prefix (str, optional): The prefix of the registered storage backend. + Options are "s3", "http", "https". Default: None. + + Examples: + >>> # only set backend + >>> file_client = FileClient(backend='petrel') + >>> # only set prefix + >>> file_client = FileClient(prefix='s3') + >>> # set both backend and prefix but use backend to choose client + >>> file_client = FileClient(backend='petrel', prefix='s3') + >>> # if the arguments are the same, the same object is returned + >>> file_client1 = FileClient(backend='petrel') + >>> file_client1 is file_client + True + + Attributes: + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'ceph': CephBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + 'petrel': PetrelBackend, + 'http': HTTPBackend, + } + # This collection is used to record the overridden backends, and when a + # backend appears in the collection, the singleton pattern is disabled for + # that backend, because if the singleton pattern is used, then the object + # returned will be the backend before overwriting + _overridden_backends = set() + _prefix_to_backends = { + 's3': PetrelBackend, + 'http': HTTPBackend, + 'https': HTTPBackend, + } + _overridden_prefixes = set() + + _instances = {} + + def __new__(cls, backend=None, prefix=None, **kwargs): + if backend is None and prefix is None: + backend = 'disk' + if backend is not None and backend not in cls._backends: + raise ValueError( + f'Backend {backend} is not supported. Currently supported ones' + f' are {list(cls._backends.keys())}') + if prefix is not None and prefix not in cls._prefix_to_backends: + raise ValueError( + f'prefix {prefix} is not supported. Currently supported ones ' + f'are {list(cls._prefix_to_backends.keys())}') + + # concatenate the arguments to a unique key for determining whether + # objects with the same arguments were created + arg_key = f'{backend}:{prefix}' + for key, value in kwargs.items(): + arg_key += f':{key}:{value}' + + # if a backend was overridden, it will create a new object + if (arg_key in cls._instances + and backend not in cls._overridden_backends + and prefix not in cls._overridden_prefixes): + _instance = cls._instances[arg_key] + else: + # create a new object and put it to _instance + _instance = super().__new__(cls) + if backend is not None: + _instance.client = cls._backends[backend](**kwargs) + else: + _instance.client = cls._prefix_to_backends[prefix](**kwargs) + + cls._instances[arg_key] = _instance + + return _instance + + @property + def name(self): + return self.client.name + + @property + def allow_symlink(self): + return self.client.allow_symlink + + @staticmethod + def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: + """Parse the prefix of a uri. + + Args: + uri (str | Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> FileClient.parse_uri_prefix('s3://path/of/your/file') + 's3' + + Returns: + str | None: Return the prefix of uri if the uri contains '://' + else ``None``. + """ + assert is_filepath(uri) + uri = str(uri) + if '://' not in uri: + return None + else: + prefix, _ = uri.split('://') + # In the case of PetrelBackend, the prefix may contains the cluster + # name like clusterName:s3 + if ':' in prefix: + _, prefix = prefix.split(':') + return prefix + + @classmethod + def infer_client(cls, + file_client_args: Optional[dict] = None, + uri: Optional[Union[str, Path]] = None) -> 'FileClient': + """Infer a suitable file client based on the URI and arguments. + + Args: + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Default: None. + uri (str | Path, optional): Uri to be parsed that contains the file + prefix. Default: None. + + Examples: + >>> uri = 's3://path/of/your/file' + >>> file_client = FileClient.infer_client(uri=uri) + >>> file_client_args = {'backend': 'petrel'} + >>> file_client = FileClient.infer_client(file_client_args) + + Returns: + FileClient: Instantiated FileClient object. + """ + assert file_client_args is not None or uri is not None + if file_client_args is None: + file_prefix = cls.parse_uri_prefix(uri) # type: ignore + return cls(prefix=file_prefix) + else: + return cls(**file_client_args) + + @classmethod + def _register_backend(cls, name, backend, force=False, prefixes=None): + if not isinstance(name, str): + raise TypeError('the backend name should be a string, ' + f'but got {type(name)}') + if not inspect.isclass(backend): + raise TypeError( + f'backend should be a class but got {type(backend)}') + if not issubclass(backend, BaseStorageBackend): + raise TypeError( + f'backend {backend} is not a subclass of BaseStorageBackend') + if not force and name in cls._backends: + raise KeyError( + f'{name} is already registered as a storage backend, ' + 'add "force=True" if you want to override it') + + if name in cls._backends and force: + cls._overridden_backends.add(name) + cls._backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if prefix not in cls._prefix_to_backends: + cls._prefix_to_backends[prefix] = backend + elif (prefix in cls._prefix_to_backends) and force: + cls._overridden_prefixes.add(prefix) + cls._prefix_to_backends[prefix] = backend + else: + raise KeyError( + f'{prefix} is already registered as a storage backend,' + ' add "force=True" if you want to override it') + + @classmethod + def register_backend(cls, name, backend=None, force=False, prefixes=None): + """Register a backend to FileClient. + + This method can be used as a normal class method or a decorator. + + .. code-block:: python + + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + FileClient.register_backend('new', NewBackend) + + or + + .. code-block:: python + + @FileClient.register_backend('new') + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool, optional): Whether to override the backend if the name + has already been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefixes + of the registered storage backend. Default: None. + `New in version 1.3.15.` + """ + if backend is not None: + cls._register_backend( + name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + cls._register_backend( + name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: + """Read data from a given ``filepath`` with 'rb' mode. + + Note: + There are two types of return values for ``get``, one is ``bytes`` + and the other is ``memoryview``. The advantage of using memoryview + is that you can avoid copying, and if you want to convert it to + ``bytes``, you can use ``.tobytes()``. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes | memoryview: Expected bytes object or a memory view of the + bytes object. + """ + return self.client.get(filepath) + + def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return self.client.get_text(filepath, encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + self.client.put(obj, filepath) + + def put_text(self, obj: str, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ + self.client.put_text(obj, filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + """ + self.client.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return self.client.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return self.client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return self.client.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return self.client.join_path(filepath, *filepaths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself. + + .. warning:: + ``get_local_path`` is an experimental interface that may change in + the future. + + Args: + filepath (str or Path): Path to be read data. + + Examples: + >>> file_client = FileClient(prefix='s3') + >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one path. + """ + with self.client.get_local_path(str(filepath)) as local_path: + yield local_path + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, + suffix, recursive) diff --git a/annotator/uniformer/mmcv/fileio/handlers/__init__.py b/annotator/uniformer/mmcv/fileio/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseFileHandler +from .json_handler import JsonHandler +from .pickle_handler import PickleHandler +from .yaml_handler import YamlHandler + +__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler'] diff --git a/annotator/uniformer/mmcv/fileio/handlers/base.py b/annotator/uniformer/mmcv/fileio/handlers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..288878bc57282fbb2f12b32290152ca8e9d3cab0 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/base.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class BaseFileHandler(metaclass=ABCMeta): + # `str_like` is a flag to indicate whether the type of file object is + # str-like object or bytes-like object. Pickle only processes bytes-like + # objects but json only processes str-like object. If it is str-like + # object, `StringIO` will be used to process the buffer. + str_like = True + + @abstractmethod + def load_from_fileobj(self, file, **kwargs): + pass + + @abstractmethod + def dump_to_fileobj(self, obj, file, **kwargs): + pass + + @abstractmethod + def dump_to_str(self, obj, **kwargs): + pass + + def load_from_path(self, filepath, mode='r', **kwargs): + with open(filepath, mode) as f: + return self.load_from_fileobj(f, **kwargs) + + def dump_to_path(self, obj, filepath, mode='w', **kwargs): + with open(filepath, mode) as f: + self.dump_to_fileobj(obj, f, **kwargs) diff --git a/annotator/uniformer/mmcv/fileio/handlers/json_handler.py b/annotator/uniformer/mmcv/fileio/handlers/json_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/json_handler.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json + +import numpy as np + +from .base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f'{type(obj)} is unsupported for json dump') + + +class JsonHandler(BaseFileHandler): + + def load_from_fileobj(self, file): + return json.load(file) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('default', set_default) + json.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('default', set_default) + return json.dumps(obj, **kwargs) diff --git a/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py b/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b37c79bed4ef9fd8913715e62dbe3fc5cafdc3aa --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pickle + +from .base import BaseFileHandler + + +class PickleHandler(BaseFileHandler): + + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return pickle.load(file, **kwargs) + + def load_from_path(self, filepath, **kwargs): + return super(PickleHandler, self).load_from_path( + filepath, mode='rb', **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('protocol', 2) + return pickle.dumps(obj, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('protocol', 2) + pickle.dump(obj, file, **kwargs) + + def dump_to_path(self, obj, filepath, **kwargs): + super(PickleHandler, self).dump_to_path( + obj, filepath, mode='wb', **kwargs) diff --git a/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py b/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c5aa2eea1e8c76f8baf753d1c8c959dee665e543 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import yaml + +try: + from yaml import CLoader as Loader, CDumper as Dumper +except ImportError: + from yaml import Loader, Dumper + +from .base import BaseFileHandler # isort:skip + + +class YamlHandler(BaseFileHandler): + + def load_from_fileobj(self, file, **kwargs): + kwargs.setdefault('Loader', Loader) + return yaml.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('Dumper', Dumper) + yaml.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('Dumper', Dumper) + return yaml.dump(obj, **kwargs) diff --git a/annotator/uniformer/mmcv/fileio/io.py b/annotator/uniformer/mmcv/fileio/io.py new file mode 100644 index 0000000000000000000000000000000000000000..aaefde58aa3ea5b58f86249ce7e1c40c186eb8dd --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/io.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from io import BytesIO, StringIO +from pathlib import Path + +from ..utils import is_list_of, is_str +from .file_client import FileClient +from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler + +file_handlers = { + 'json': JsonHandler(), + 'yaml': YamlHandler(), + 'yml': YamlHandler(), + 'pickle': PickleHandler(), + 'pkl': PickleHandler() +} + + +def load(file, file_format=None, file_client_args=None, **kwargs): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + Note: + In v1.3.16 and later, ``load`` supports loading data from serialized + files those can be storaged in different backends. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml" and + "pickle/pkl". + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('s3://path/of/your/file') # file is storaged in petrel + + Returns: + The content from the file. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None and is_str(file): + file_format = file.split('.')[-1] + if file_format not in file_handlers: + raise TypeError(f'Unsupported format: {file_format}') + + handler = file_handlers[file_format] + if is_str(file): + file_client = FileClient.infer_client(file_client_args, file) + if handler.str_like: + with StringIO(file_client.get_text(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + with BytesIO(file_client.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + elif hasattr(file, 'read'): + obj = handler.load_from_fileobj(file, **kwargs) + else: + raise TypeError('"file" must be a filepath str or a file-object') + return obj + + +def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): + """Dump data to json/yaml/pickle strings or files. + + This method provides a unified api for dumping data as strings or to files, + and also supports custom arguments for each file format. + + Note: + In v1.3.16 and later, ``dump`` supports dumping data as strings or to + files which is saved to different backends. + + Args: + obj (any): The python object to be dumped. + file (str or :obj:`Path` or file-like object, optional): If not + specified, then the object is dumped to a str, otherwise to a file + specified by the filename or file-like object. + file_format (str, optional): Same as :func:`load`. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel + + Returns: + bool: True for success, False otherwise. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None: + if is_str(file): + file_format = file.split('.')[-1] + elif file is None: + raise ValueError( + 'file_format must be specified since file is None') + if file_format not in file_handlers: + raise TypeError(f'Unsupported format: {file_format}') + + handler = file_handlers[file_format] + if file is None: + return handler.dump_to_str(obj, **kwargs) + elif is_str(file): + file_client = FileClient.infer_client(file_client_args, file) + if handler.str_like: + with StringIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_client.put_text(f.getvalue(), file) + else: + with BytesIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_client.put(f.getvalue(), file) + elif hasattr(file, 'write'): + handler.dump_to_fileobj(obj, file, **kwargs) + else: + raise TypeError('"file" must be a filename str or a file-object') + + +def _register_handler(handler, file_formats): + """Register a handler for some file extensions. + + Args: + handler (:obj:`BaseFileHandler`): Handler to be registered. + file_formats (str or list[str]): File formats to be handled by this + handler. + """ + if not isinstance(handler, BaseFileHandler): + raise TypeError( + f'handler must be a child of BaseFileHandler, not {type(handler)}') + if isinstance(file_formats, str): + file_formats = [file_formats] + if not is_list_of(file_formats, str): + raise TypeError('file_formats must be a str or a list of str') + for ext in file_formats: + file_handlers[ext] = handler + + +def register_handler(file_formats, **kwargs): + + def wrap(cls): + _register_handler(cls(**kwargs), file_formats) + return cls + + return wrap diff --git a/annotator/uniformer/mmcv/fileio/parse.py b/annotator/uniformer/mmcv/fileio/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..f60f0d611b8d75692221d0edd7dc993b0a6445c9 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/parse.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from io import StringIO + +from .file_client import FileClient + + +def list_from_file(filename, + prefix='', + offset=0, + max_num=0, + encoding='utf-8', + file_client_args=None): + """Load a text file and parse the content as a list of strings. + + Note: + In v1.3.16 and later, ``list_from_file`` supports loading a text file + which can be storaged in different backends and parsing the content as + a list for strings. + + Args: + filename (str): Filename. + prefix (str): The prefix to be inserted to the beginning of each item. + offset (int): The offset of lines. + max_num (int): The maximum number of lines to be read, + zeros and negatives mean no limitation. + encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> list_from_file('/path/of/your/file') # disk + ['hello', 'world'] + >>> list_from_file('s3://path/of/your/file') # ceph or petrel + ['hello', 'world'] + + Returns: + list[str]: A list of strings. + """ + cnt = 0 + item_list = [] + file_client = FileClient.infer_client(file_client_args, filename) + with StringIO(file_client.get_text(filename, encoding)) as f: + for _ in range(offset): + f.readline() + for line in f: + if 0 < max_num <= cnt: + break + item_list.append(prefix + line.rstrip('\n\r')) + cnt += 1 + return item_list + + +def dict_from_file(filename, + key_type=str, + encoding='utf-8', + file_client_args=None): + """Load a text file and parse the content as a dict. + + Each line of the text file will be two or more columns split by + whitespaces or tabs. The first column will be parsed as dict keys, and + the following columns will be parsed as dict values. + + Note: + In v1.3.16 and later, ``dict_from_file`` supports loading a text file + which can be storaged in different backends and parsing the content as + a dict. + + Args: + filename(str): Filename. + key_type(type): Type of the dict keys. str is user by default and + type conversion will be performed if specified. + encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dict_from_file('/path/of/your/file') # disk + {'key1': 'value1', 'key2': 'value2'} + >>> dict_from_file('s3://path/of/your/file') # ceph or petrel + {'key1': 'value1', 'key2': 'value2'} + + Returns: + dict: The parsed contents. + """ + mapping = {} + file_client = FileClient.infer_client(file_client_args, filename) + with StringIO(file_client.get_text(filename, encoding)) as f: + for line in f: + items = line.rstrip('\n').split() + assert len(items) >= 2 + key = key_type(items[0]) + val = items[1:] if len(items) > 2 else items[1] + mapping[key] = val + return mapping diff --git a/annotator/uniformer/mmcv/image/__init__.py b/annotator/uniformer/mmcv/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0051d609d3de4e7562e3fe638335c66617c4d91 --- /dev/null +++ b/annotator/uniformer/mmcv/image/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr, + gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert, + rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb) +from .geometric import (cutout, imcrop, imflip, imflip_, impad, + impad_to_multiple, imrescale, imresize, imresize_like, + imresize_to_multiple, imrotate, imshear, imtranslate, + rescale_size) +from .io import imfrombytes, imread, imwrite, supported_backends, use_backend +from .misc import tensor2imgs +from .photometric import (adjust_brightness, adjust_color, adjust_contrast, + adjust_lighting, adjust_sharpness, auto_contrast, + clahe, imdenormalize, imequalize, iminvert, + imnormalize, imnormalize_, lut_transform, posterize, + solarize) + +__all__ = [ + 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb', + 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale', + 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size', + 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate', + 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend', + 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize', + 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', + 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize', + 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe', + 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting' +] diff --git a/annotator/uniformer/mmcv/image/colorspace.py b/annotator/uniformer/mmcv/image/colorspace.py new file mode 100644 index 0000000000000000000000000000000000000000..814533952fdfda23d67cb6a3073692d8c1156add --- /dev/null +++ b/annotator/uniformer/mmcv/image/colorspace.py @@ -0,0 +1,306 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + + +def imconvert(img, src, dst): + """Convert an image from the src colorspace to dst colorspace. + + Args: + img (ndarray): The input image. + src (str): The source colorspace, e.g., 'rgb', 'hsv'. + dst (str): The destination colorspace, e.g., 'rgb', 'hsv'. + + Returns: + ndarray: The converted image. + """ + code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}') + out_img = cv2.cvtColor(img, code) + return out_img + + +def bgr2gray(img, keepdim=False): + """Convert a BGR image to grayscale image. + + Args: + img (ndarray): The input image. + keepdim (bool): If False (by default), then return the grayscale image + with 2 dims, otherwise 3 dims. + + Returns: + ndarray: The converted grayscale image. + """ + out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + if keepdim: + out_img = out_img[..., None] + return out_img + + +def rgb2gray(img, keepdim=False): + """Convert a RGB image to grayscale image. + + Args: + img (ndarray): The input image. + keepdim (bool): If False (by default), then return the grayscale image + with 2 dims, otherwise 3 dims. + + Returns: + ndarray: The converted grayscale image. + """ + out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + if keepdim: + out_img = out_img[..., None] + return out_img + + +def gray2bgr(img): + """Convert a grayscale image to BGR image. + + Args: + img (ndarray): The input image. + + Returns: + ndarray: The converted BGR image. + """ + img = img[..., None] if img.ndim == 2 else img + out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + return out_img + + +def gray2rgb(img): + """Convert a grayscale image to RGB image. + + Args: + img (ndarray): The input image. + + Returns: + ndarray: The converted RGB image. + """ + img = img[..., None] if img.ndim == 2 else img + out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' + f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' + f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [ + -276.836, 135.576, -222.921 + ] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def convert_color_factory(src, dst): + + code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}') + + def convert_color(img): + out_img = cv2.cvtColor(img, code) + return out_img + + convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()} + image. + + Args: + img (ndarray or str): The input image. + + Returns: + ndarray: The converted {dst.upper()} image. + """ + + return convert_color + + +bgr2rgb = convert_color_factory('bgr', 'rgb') + +rgb2bgr = convert_color_factory('rgb', 'bgr') + +bgr2hsv = convert_color_factory('bgr', 'hsv') + +hsv2bgr = convert_color_factory('hsv', 'bgr') + +bgr2hls = convert_color_factory('bgr', 'hls') + +hls2bgr = convert_color_factory('hls', 'bgr') diff --git a/annotator/uniformer/mmcv/image/geometric.py b/annotator/uniformer/mmcv/image/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..cf97c201cb4e43796c911919d03fb26a07ed817d --- /dev/null +++ b/annotator/uniformer/mmcv/image/geometric.py @@ -0,0 +1,728 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numbers + +import cv2 +import numpy as np + +from ..utils import to_2tuple +from .io import imread_backend + +try: + from PIL import Image +except ImportError: + Image = None + + +def _scale_size(size, scale): + """Rescale a size by a ratio. + + Args: + size (tuple[int]): (w, h). + scale (float | tuple(float)): Scaling factor. + + Returns: + tuple[int]: scaled size. + """ + if isinstance(scale, (float, int)): + scale = (scale, scale) + w, h = size + return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5) + + +cv2_interp_codes = { + 'nearest': cv2.INTER_NEAREST, + 'bilinear': cv2.INTER_LINEAR, + 'bicubic': cv2.INTER_CUBIC, + 'area': cv2.INTER_AREA, + 'lanczos': cv2.INTER_LANCZOS4 +} + +if Image is not None: + pillow_interp_codes = { + 'nearest': Image.NEAREST, + 'bilinear': Image.BILINEAR, + 'bicubic': Image.BICUBIC, + 'box': Image.BOX, + 'lanczos': Image.LANCZOS, + 'hamming': Image.HAMMING + } + + +def imresize(img, + size, + return_scale=False, + interpolation='bilinear', + out=None, + backend=None): + """Resize image to a given size. + + Args: + img (ndarray): The input image. + size (tuple[int]): Target size (w, h). + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if backend is None: + backend = imread_backend + if backend not in ['cv2', 'pillow']: + raise ValueError(f'backend: {backend} is not supported for resize.' + f"Supported backends are 'cv2', 'pillow'") + + if backend == 'pillow': + assert img.dtype == np.uint8, 'Pillow backend only support uint8 type' + pil_image = Image.fromarray(img) + pil_image = pil_image.resize(size, pillow_interp_codes[interpolation]) + resized_img = np.array(pil_image) + else: + resized_img = cv2.resize( + img, size, dst=out, interpolation=cv2_interp_codes[interpolation]) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + + +def imresize_to_multiple(img, + divisor, + size=None, + scale_factor=None, + keep_ratio=False, + return_scale=False, + interpolation='bilinear', + out=None, + backend=None): + """Resize image according to a given size or scale factor and then rounds + up the the resized or rescaled image size to the nearest value that can be + divided by the divisor. + + Args: + img (ndarray): The input image. + divisor (int | tuple): Resized image size will be a multiple of + divisor. If divisor is a tuple, divisor should be + (w_divisor, h_divisor). + size (None | int | tuple[int]): Target size (w, h). Default: None. + scale_factor (None | float | tuple[float]): Multiplier for spatial + size. Should match input size if it is a tuple and the 2D style is + (w_scale_factor, h_scale_factor). Default: None. + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Default: False. + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if size is not None and scale_factor is not None: + raise ValueError('only one of size or scale_factor should be defined') + elif size is None and scale_factor is None: + raise ValueError('one of size or scale_factor should be defined') + elif size is not None: + size = to_2tuple(size) + if keep_ratio: + size = rescale_size((w, h), size, return_scale=False) + else: + size = _scale_size((w, h), scale_factor) + + divisor = to_2tuple(divisor) + size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)]) + resized_img, w_scale, h_scale = imresize( + img, + size, + return_scale=True, + interpolation=interpolation, + out=out, + backend=backend) + if return_scale: + return resized_img, w_scale, h_scale + else: + return resized_img + + +def imresize_like(img, + dst_img, + return_scale=False, + interpolation='bilinear', + backend=None): + """Resize image to the same size of a given image. + + Args: + img (ndarray): The input image. + dst_img (ndarray): The target image. + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = dst_img.shape[:2] + return imresize(img, (w, h), return_scale, interpolation, backend=backend) + + +def rescale_size(old_size, scale, return_scale=False): + """Calculate the new size to be rescaled to. + + Args: + old_size (tuple[int]): The old size (w, h) of image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image size. + + Returns: + tuple[int]: The new rescaled image size. + """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError(f'Invalid scale {scale}, must be positive.') + scale_factor = scale + elif isinstance(scale, tuple): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + else: + raise TypeError( + f'Scale must be a number or tuple of int, but got {type(scale)}') + + new_size = _scale_size((w, h), scale_factor) + + if return_scale: + return new_size, scale_factor + else: + return new_size + + +def imrescale(img, + scale, + return_scale=False, + interpolation='bilinear', + backend=None): + """Resize image while keeping the aspect ratio. + + Args: + img (ndarray): The input image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + ndarray: The rescaled image. + """ + h, w = img.shape[:2] + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) + rescaled_img = imresize( + img, new_size, interpolation=interpolation, backend=backend) + if return_scale: + return rescaled_img, scale_factor + else: + return rescaled_img + + +def imflip(img, direction='horizontal'): + """Flip an image horizontally or vertically. + + Args: + img (ndarray): Image to be flipped. + direction (str): The flip direction, either "horizontal" or + "vertical" or "diagonal". + + Returns: + ndarray: The flipped image. + """ + assert direction in ['horizontal', 'vertical', 'diagonal'] + if direction == 'horizontal': + return np.flip(img, axis=1) + elif direction == 'vertical': + return np.flip(img, axis=0) + else: + return np.flip(img, axis=(0, 1)) + + +def imflip_(img, direction='horizontal'): + """Inplace flip an image horizontally or vertically. + + Args: + img (ndarray): Image to be flipped. + direction (str): The flip direction, either "horizontal" or + "vertical" or "diagonal". + + Returns: + ndarray: The flipped image (inplace). + """ + assert direction in ['horizontal', 'vertical', 'diagonal'] + if direction == 'horizontal': + return cv2.flip(img, 1, img) + elif direction == 'vertical': + return cv2.flip(img, 0, img) + else: + return cv2.flip(img, -1, img) + + +def imrotate(img, + angle, + center=None, + scale=1.0, + border_value=0, + interpolation='bilinear', + auto_bound=False): + """Rotate an image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees, positive values mean + clockwise rotation. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. + scale (float): Isotropic scale factor. + border_value (int): Border value. + interpolation (str): Same as :func:`resize`. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. + + Returns: + ndarray: The rotated image. + """ + if center is not None and auto_bound: + raise ValueError('`auto_bound` conflicts with `center`') + h, w = img.shape[:2] + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + assert isinstance(center, tuple) + + matrix = cv2.getRotationMatrix2D(center, -angle, scale) + if auto_bound: + cos = np.abs(matrix[0, 0]) + sin = np.abs(matrix[0, 1]) + new_w = h * sin + w * cos + new_h = h * cos + w * sin + matrix[0, 2] += (new_w - w) * 0.5 + matrix[1, 2] += (new_h - h) * 0.5 + w = int(np.round(new_w)) + h = int(np.round(new_h)) + rotated = cv2.warpAffine( + img, + matrix, (w, h), + flags=cv2_interp_codes[interpolation], + borderValue=border_value) + return rotated + + +def bbox_clip(bboxes, img_shape): + """Clip bboxes to fit the image shape. + + Args: + bboxes (ndarray): Shape (..., 4*k) + img_shape (tuple[int]): (height, width) of the image. + + Returns: + ndarray: Clipped bboxes. + """ + assert bboxes.shape[-1] % 4 == 0 + cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype) + cmin[0::2] = img_shape[1] - 1 + cmin[1::2] = img_shape[0] - 1 + clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0) + return clipped_bboxes + + +def bbox_scaling(bboxes, scale, clip_shape=None): + """Scaling bboxes w.r.t the box center. + + Args: + bboxes (ndarray): Shape(..., 4). + scale (float): Scaling factor. + clip_shape (tuple[int], optional): If specified, bboxes that exceed the + boundary will be clipped according to the given shape (h, w). + + Returns: + ndarray: Scaled bboxes. + """ + if float(scale) == 1.0: + scaled_bboxes = bboxes.copy() + else: + w = bboxes[..., 2] - bboxes[..., 0] + 1 + h = bboxes[..., 3] - bboxes[..., 1] + 1 + dw = (w * (scale - 1)) * 0.5 + dh = (h * (scale - 1)) * 0.5 + scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1) + if clip_shape is not None: + return bbox_clip(scaled_bboxes, clip_shape) + else: + return scaled_bboxes + + +def imcrop(img, bboxes, scale=1.0, pad_fill=None): + """Crop image patches. + + 3 steps: scale the bboxes -> clip bboxes -> crop and pad. + + Args: + img (ndarray): Image to be cropped. + bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes. + scale (float, optional): Scale ratio of bboxes, the default value + 1.0 means no padding. + pad_fill (Number | list[Number]): Value to be filled for padding. + Default: None, which means no padding. + + Returns: + list[ndarray] | ndarray: The cropped image patches. + """ + chn = 1 if img.ndim == 2 else img.shape[2] + if pad_fill is not None: + if isinstance(pad_fill, (int, float)): + pad_fill = [pad_fill for _ in range(chn)] + assert len(pad_fill) == chn + + _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes + scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32) + clipped_bbox = bbox_clip(scaled_bboxes, img.shape) + + patches = [] + for i in range(clipped_bbox.shape[0]): + x1, y1, x2, y2 = tuple(clipped_bbox[i, :]) + if pad_fill is None: + patch = img[y1:y2 + 1, x1:x2 + 1, ...] + else: + _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :]) + if chn == 1: + patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1) + else: + patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn) + patch = np.array( + pad_fill, dtype=img.dtype) * np.ones( + patch_shape, dtype=img.dtype) + x_start = 0 if _x1 >= 0 else -_x1 + y_start = 0 if _y1 >= 0 else -_y1 + w = x2 - x1 + 1 + h = y2 - y1 + 1 + patch[y_start:y_start + h, x_start:x_start + w, + ...] = img[y1:y1 + h, x1:x1 + w, ...] + patches.append(patch) + + if bboxes.ndim == 1: + return patches[0] + else: + return patches + + +def impad(img, + *, + shape=None, + padding=None, + pad_val=0, + padding_mode='constant'): + """Pad the given image to a certain shape or pad on all sides with + specified padding mode and padding value. + + Args: + img (ndarray): Image to be padded. + shape (tuple[int]): Expected padding shape (h, w). Default: None. + padding (int or tuple[int]): Padding on each border. If a single int is + provided this is used to pad all borders. If tuple of length 2 is + provided this is the padding on left/right and top/bottom + respectively. If a tuple of length 4 is provided this is the + padding for the left, top, right and bottom borders respectively. + Default: None. Note that `shape` and `padding` can not be both + set. + pad_val (Number | Sequence[Number]): Values to be filled in padding + areas when padding_mode is 'constant'. Default: 0. + padding_mode (str): Type of padding. Should be: constant, edge, + reflect or symmetric. Default: constant. + + - constant: pads with a constant value, this value is specified + with pad_val. + - edge: pads with the last value at the edge of the image. + - reflect: pads with reflection of image without repeating the + last value on the edge. For example, padding [1, 2, 3, 4] + with 2 elements on both sides in reflect mode will result + in [3, 2, 1, 2, 3, 4, 3, 2]. + - symmetric: pads with reflection of image repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with + 2 elements on both sides in symmetric mode will result in + [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + ndarray: The padded image. + """ + + assert (shape is not None) ^ (padding is not None) + if shape is not None: + padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0]) + + # check pad_val + if isinstance(pad_val, tuple): + assert len(pad_val) == img.shape[-1] + elif not isinstance(pad_val, numbers.Number): + raise TypeError('pad_val must be a int or a tuple. ' + f'But received {type(pad_val)}') + + # check padding + if isinstance(padding, tuple) and len(padding) in [2, 4]: + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + elif isinstance(padding, numbers.Number): + padding = (padding, padding, padding, padding) + else: + raise ValueError('Padding must be a int or a 2, or 4 element tuple.' + f'But received {padding}') + + # check padding mode + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + + border_type = { + 'constant': cv2.BORDER_CONSTANT, + 'edge': cv2.BORDER_REPLICATE, + 'reflect': cv2.BORDER_REFLECT_101, + 'symmetric': cv2.BORDER_REFLECT + } + img = cv2.copyMakeBorder( + img, + padding[1], + padding[3], + padding[0], + padding[2], + border_type[padding_mode], + value=pad_val) + + return img + + +def impad_to_multiple(img, divisor, pad_val=0): + """Pad an image to ensure each edge to be multiple to some number. + + Args: + img (ndarray): Image to be padded. + divisor (int): Padded image edges will be multiple to divisor. + pad_val (Number | Sequence[Number]): Same as :func:`impad`. + + Returns: + ndarray: The padded image. + """ + pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor + pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor + return impad(img, shape=(pad_h, pad_w), pad_val=pad_val) + + +def cutout(img, shape, pad_val=0): + """Randomly cut out a rectangle from the original img. + + Args: + img (ndarray): Image to be cutout. + shape (int | tuple[int]): Expected cutout shape (h, w). If given as a + int, the value will be used for both h and w. + pad_val (int | float | tuple[int | float]): Values to be filled in the + cut area. Defaults to 0. + + Returns: + ndarray: The cutout image. + """ + + channels = 1 if img.ndim == 2 else img.shape[2] + if isinstance(shape, int): + cut_h, cut_w = shape, shape + else: + assert isinstance(shape, tuple) and len(shape) == 2, \ + f'shape must be a int or a tuple with length 2, but got type ' \ + f'{type(shape)} instead.' + cut_h, cut_w = shape + if isinstance(pad_val, (int, float)): + pad_val = tuple([pad_val] * channels) + elif isinstance(pad_val, tuple): + assert len(pad_val) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. Found {} vs {}'.format( + len(pad_val), channels) + else: + raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`') + + img_h, img_w = img.shape[:2] + y0 = np.random.uniform(img_h) + x0 = np.random.uniform(img_w) + + y1 = int(max(0, y0 - cut_h / 2.)) + x1 = int(max(0, x0 - cut_w / 2.)) + y2 = min(img_h, y1 + cut_h) + x2 = min(img_w, x1 + cut_w) + + if img.ndim == 2: + patch_shape = (y2 - y1, x2 - x1) + else: + patch_shape = (y2 - y1, x2 - x1, channels) + + img_cutout = img.copy() + patch = np.array( + pad_val, dtype=img.dtype) * np.ones( + patch_shape, dtype=img.dtype) + img_cutout[y1:y2, x1:x2, ...] = patch + + return img_cutout + + +def _get_shear_matrix(magnitude, direction='horizontal'): + """Generate the shear matrix for transformation. + + Args: + magnitude (int | float): The magnitude used for shear. + direction (str): The flip direction, either "horizontal" + or "vertical". + + Returns: + ndarray: The shear matrix with dtype float32. + """ + if direction == 'horizontal': + shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]]) + elif direction == 'vertical': + shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]]) + return shear_matrix + + +def imshear(img, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Shear an image. + + Args: + img (ndarray): Image to be sheared with format (h, w) + or (h, w, c). + magnitude (int | float): The magnitude used for shear. + direction (str): The flip direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. + interpolation (str): Same as :func:`resize`. + + Returns: + ndarray: The sheared image. + """ + assert direction in ['horizontal', + 'vertical'], f'Invalid direction: {direction}' + height, width = img.shape[:2] + if img.ndim == 2: + channels = 1 + elif img.ndim == 3: + channels = img.shape[-1] + if isinstance(border_value, int): + border_value = tuple([border_value] * channels) + elif isinstance(border_value, tuple): + assert len(border_value) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. Found {} vs {}'.format( + len(border_value), channels) + else: + raise ValueError( + f'Invalid type {type(border_value)} for `border_value`') + shear_matrix = _get_shear_matrix(magnitude, direction) + sheared = cv2.warpAffine( + img, + shear_matrix, + (width, height), + # Note case when the number elements in `border_value` + # greater than 3 (e.g. shearing masks whose channels large + # than 3) will raise TypeError in `cv2.warpAffine`. + # Here simply slice the first 3 values in `border_value`. + borderValue=border_value[:3], + flags=cv2_interp_codes[interpolation]) + return sheared + + +def _get_translate_matrix(offset, direction='horizontal'): + """Generate the translate matrix. + + Args: + offset (int | float): The offset used for translate. + direction (str): The translate direction, either + "horizontal" or "vertical". + + Returns: + ndarray: The translate matrix with dtype float32. + """ + if direction == 'horizontal': + translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]]) + elif direction == 'vertical': + translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]]) + return translate_matrix + + +def imtranslate(img, + offset, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Translate an image. + + Args: + img (ndarray): Image to be translated with format + (h, w) or (h, w, c). + offset (int | float): The offset used for translate. + direction (str): The translate direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. + interpolation (str): Same as :func:`resize`. + + Returns: + ndarray: The translated image. + """ + assert direction in ['horizontal', + 'vertical'], f'Invalid direction: {direction}' + height, width = img.shape[:2] + if img.ndim == 2: + channels = 1 + elif img.ndim == 3: + channels = img.shape[-1] + if isinstance(border_value, int): + border_value = tuple([border_value] * channels) + elif isinstance(border_value, tuple): + assert len(border_value) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. Found {} vs {}'.format( + len(border_value), channels) + else: + raise ValueError( + f'Invalid type {type(border_value)} for `border_value`.') + translate_matrix = _get_translate_matrix(offset, direction) + translated = cv2.warpAffine( + img, + translate_matrix, + (width, height), + # Note case when the number elements in `border_value` + # greater than 3 (e.g. translating masks whose channels + # large than 3) will raise TypeError in `cv2.warpAffine`. + # Here simply slice the first 3 values in `border_value`. + borderValue=border_value[:3], + flags=cv2_interp_codes[interpolation]) + return translated diff --git a/annotator/uniformer/mmcv/image/io.py b/annotator/uniformer/mmcv/image/io.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fa2e8cc06b1a7b0b69de6406980b15d61a1e5d --- /dev/null +++ b/annotator/uniformer/mmcv/image/io.py @@ -0,0 +1,258 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import io +import os.path as osp +from pathlib import Path + +import cv2 +import numpy as np +from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION, + IMREAD_UNCHANGED) + +from annotator.uniformer.mmcv.utils import check_file_exist, is_str, mkdir_or_exist + +try: + from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG +except ImportError: + TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None + +try: + from PIL import Image, ImageOps +except ImportError: + Image = None + +try: + import tifffile +except ImportError: + tifffile = None + +jpeg = None +supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile'] + +imread_flags = { + 'color': IMREAD_COLOR, + 'grayscale': IMREAD_GRAYSCALE, + 'unchanged': IMREAD_UNCHANGED, + 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR, + 'grayscale_ignore_orientation': + IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE +} + +imread_backend = 'cv2' + + +def use_backend(backend): + """Select a backend for image decoding. + + Args: + backend (str): The image decoding backend type. Options are `cv2`, + `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG) + and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg` + file format. + """ + assert backend in supported_backends + global imread_backend + imread_backend = backend + if imread_backend == 'turbojpeg': + if TurboJPEG is None: + raise ImportError('`PyTurboJPEG` is not installed') + global jpeg + if jpeg is None: + jpeg = TurboJPEG() + elif imread_backend == 'pillow': + if Image is None: + raise ImportError('`Pillow` is not installed') + elif imread_backend == 'tifffile': + if tifffile is None: + raise ImportError('`tifffile` is not installed') + + +def _jpegflag(flag='color', channel_order='bgr'): + channel_order = channel_order.lower() + if channel_order not in ['rgb', 'bgr']: + raise ValueError('channel order must be either "rgb" or "bgr"') + + if flag == 'color': + if channel_order == 'bgr': + return TJPF_BGR + elif channel_order == 'rgb': + return TJCS_RGB + elif flag == 'grayscale': + return TJPF_GRAY + else: + raise ValueError('flag must be "color" or "grayscale"') + + +def _pillow2array(img, flag='color', channel_order='bgr'): + """Convert a pillow image to numpy array. + + Args: + img (:obj:`PIL.Image.Image`): The image loaded using PIL + flag (str): Flags specifying the color type of a loaded image, + candidates are 'color', 'grayscale' and 'unchanged'. + Default to 'color'. + channel_order (str): The channel order of the output image array, + candidates are 'bgr' and 'rgb'. Default to 'bgr'. + + Returns: + np.ndarray: The converted numpy array + """ + channel_order = channel_order.lower() + if channel_order not in ['rgb', 'bgr']: + raise ValueError('channel order must be either "rgb" or "bgr"') + + if flag == 'unchanged': + array = np.array(img) + if array.ndim >= 3 and array.shape[2] >= 3: # color image + array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR + else: + # Handle exif orientation tag + if flag in ['color', 'grayscale']: + img = ImageOps.exif_transpose(img) + # If the image mode is not 'RGB', convert it to 'RGB' first. + if img.mode != 'RGB': + if img.mode != 'LA': + # Most formats except 'LA' can be directly converted to RGB + img = img.convert('RGB') + else: + # When the mode is 'LA', the default conversion will fill in + # the canvas with black, which sometimes shadows black objects + # in the foreground. + # + # Therefore, a random color (124, 117, 104) is used for canvas + img_rgba = img.convert('RGBA') + img = Image.new('RGB', img_rgba.size, (124, 117, 104)) + img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha + if flag in ['color', 'color_ignore_orientation']: + array = np.array(img) + if channel_order != 'rgb': + array = array[:, :, ::-1] # RGB to BGR + elif flag in ['grayscale', 'grayscale_ignore_orientation']: + img = img.convert('L') + array = np.array(img) + else: + raise ValueError( + 'flag must be "color", "grayscale", "unchanged", ' + f'"color_ignore_orientation" or "grayscale_ignore_orientation"' + f' but got {flag}') + return array + + +def imread(img_or_path, flag='color', channel_order='bgr', backend=None): + """Read an image. + + Args: + img_or_path (ndarray or str or Path): Either a numpy array or str or + pathlib.Path. If it is a numpy array (loaded image), then + it will be returned as is. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale`, `unchanged`, + `color_ignore_orientation` and `grayscale_ignore_orientation`. + By default, `cv2` and `pillow` backend would rotate the image + according to its EXIF info unless called with `unchanged` or + `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend + always ignore image's EXIF info regardless of the flag. + The `turbojpeg` backend only supports `color` and `grayscale`. + channel_order (str): Order of channel, candidates are `bgr` and `rgb`. + backend (str | None): The image decoding backend type. Options are + `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. + If backend is None, the global imread_backend specified by + ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + ndarray: Loaded image array. + """ + + if backend is None: + backend = imread_backend + if backend not in supported_backends: + raise ValueError(f'backend: {backend} is not supported. Supported ' + "backends are 'cv2', 'turbojpeg', 'pillow'") + if isinstance(img_or_path, Path): + img_or_path = str(img_or_path) + + if isinstance(img_or_path, np.ndarray): + return img_or_path + elif is_str(img_or_path): + check_file_exist(img_or_path, + f'img file does not exist: {img_or_path}') + if backend == 'turbojpeg': + with open(img_or_path, 'rb') as in_file: + img = jpeg.decode(in_file.read(), + _jpegflag(flag, channel_order)) + if img.shape[-1] == 1: + img = img[:, :, 0] + return img + elif backend == 'pillow': + img = Image.open(img_or_path) + img = _pillow2array(img, flag, channel_order) + return img + elif backend == 'tifffile': + img = tifffile.imread(img_or_path) + return img + else: + flag = imread_flags[flag] if is_str(flag) else flag + img = cv2.imread(img_or_path, flag) + if flag == IMREAD_COLOR and channel_order == 'rgb': + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + return img + else: + raise TypeError('"img" must be a numpy array or a str or ' + 'a pathlib.Path object') + + +def imfrombytes(content, flag='color', channel_order='bgr', backend=None): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Same as :func:`imread`. + backend (str | None): The image decoding backend type. Options are + `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the + global imread_backend specified by ``mmcv.use_backend()`` will be + used. Default: None. + + Returns: + ndarray: Loaded image array. + """ + + if backend is None: + backend = imread_backend + if backend not in supported_backends: + raise ValueError(f'backend: {backend} is not supported. Supported ' + "backends are 'cv2', 'turbojpeg', 'pillow'") + if backend == 'turbojpeg': + img = jpeg.decode(content, _jpegflag(flag, channel_order)) + if img.shape[-1] == 1: + img = img[:, :, 0] + return img + elif backend == 'pillow': + buff = io.BytesIO(content) + img = Image.open(buff) + img = _pillow2array(img, flag, channel_order) + return img + else: + img_np = np.frombuffer(content, np.uint8) + flag = imread_flags[flag] if is_str(flag) else flag + img = cv2.imdecode(img_np, flag) + if flag == IMREAD_COLOR and channel_order == 'rgb': + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = osp.abspath(osp.dirname(file_path)) + mkdir_or_exist(dir_name) + return cv2.imwrite(file_path, img, params) diff --git a/annotator/uniformer/mmcv/image/misc.py b/annotator/uniformer/mmcv/image/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..3e61f05e3b05e4c7b40de4eb6c8eb100e6da41d0 --- /dev/null +++ b/annotator/uniformer/mmcv/image/misc.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +import annotator.uniformer.mmcv as mmcv + +try: + import torch +except ImportError: + torch = None + + +def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True): + """Convert tensor to 3-channel images. + + Args: + tensor (torch.Tensor): Tensor that contains multiple images, shape ( + N, C, H, W). + mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0). + std (tuple[float], optional): Standard deviation of images. + Defaults to (1, 1, 1). + to_rgb (bool, optional): Whether the tensor was converted to RGB + format in the first place. If so, convert it back to BGR. + Defaults to True. + + Returns: + list[np.ndarray]: A list that contains multiple images. + """ + + if torch is None: + raise RuntimeError('pytorch is not installed') + assert torch.is_tensor(tensor) and tensor.ndim == 4 + assert len(mean) == 3 + assert len(std) == 3 + + num_imgs = tensor.size(0) + mean = np.array(mean, dtype=np.float32) + std = np.array(std, dtype=np.float32) + imgs = [] + for img_id in range(num_imgs): + img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0) + img = mmcv.imdenormalize( + img, mean, std, to_bgr=to_rgb).astype(np.uint8) + imgs.append(np.ascontiguousarray(img)) + return imgs diff --git a/annotator/uniformer/mmcv/image/photometric.py b/annotator/uniformer/mmcv/image/photometric.py new file mode 100644 index 0000000000000000000000000000000000000000..5085d012019c0cbf56f66f421a378278c1a058ae --- /dev/null +++ b/annotator/uniformer/mmcv/image/photometric.py @@ -0,0 +1,428 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + +from ..utils import is_tuple_of +from .colorspace import bgr2gray, gray2bgr + + +def imnormalize(img, mean, std, to_rgb=True): + """Normalize an image with mean and std. + + Args: + img (ndarray): Image to be normalized. + mean (ndarray): The mean to be used for normalize. + std (ndarray): The std to be used for normalize. + to_rgb (bool): Whether to convert to rgb. + + Returns: + ndarray: The normalized image. + """ + img = img.copy().astype(np.float32) + return imnormalize_(img, mean, std, to_rgb) + + +def imnormalize_(img, mean, std, to_rgb=True): + """Inplace normalize an image with mean and std. + + Args: + img (ndarray): Image to be normalized. + mean (ndarray): The mean to be used for normalize. + std (ndarray): The std to be used for normalize. + to_rgb (bool): Whether to convert to rgb. + + Returns: + ndarray: The normalized image. + """ + # cv2 inplace normalization does not accept uint8 + assert img.dtype != np.uint8 + mean = np.float64(mean.reshape(1, -1)) + stdinv = 1 / np.float64(std.reshape(1, -1)) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + cv2.subtract(img, mean, img) # inplace + cv2.multiply(img, stdinv, img) # inplace + return img + + +def imdenormalize(img, mean, std, to_bgr=True): + assert img.dtype != np.uint8 + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = cv2.multiply(img, std) # make a copy + cv2.add(img, mean, img) # inplace + if to_bgr: + cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace + return img + + +def iminvert(img): + """Invert (negate) an image. + + Args: + img (ndarray): Image to be inverted. + + Returns: + ndarray: The inverted image. + """ + return np.full_like(img, 255) - img + + +def solarize(img, thr=128): + """Solarize an image (invert all pixel values above a threshold) + + Args: + img (ndarray): Image to be solarized. + thr (int): Threshold for solarizing (0 - 255). + + Returns: + ndarray: The solarized image. + """ + img = np.where(img < thr, img, 255 - img) + return img + + +def posterize(img, bits): + """Posterize an image (reduce the number of bits for each color channel) + + Args: + img (ndarray): Image to be posterized. + bits (int): Number of bits (1 to 8) to use for posterizing. + + Returns: + ndarray: The posterized image. + """ + shift = 8 - bits + img = np.left_shift(np.right_shift(img, shift), shift) + return img + + +def adjust_color(img, alpha=1, beta=None, gamma=0): + r"""It blends the source image and its gray image: + + .. math:: + output = img * alpha + gray\_img * beta + gamma + + Args: + img (ndarray): The input source image. + alpha (int | float): Weight for the source image. Default 1. + beta (int | float): Weight for the converted gray image. + If None, it's assigned the value (1 - `alpha`). + gamma (int | float): Scalar added to each sum. + Same as :func:`cv2.addWeighted`. Default 0. + + Returns: + ndarray: Colored image which has the same size and dtype as input. + """ + gray_img = bgr2gray(img) + gray_img = np.tile(gray_img[..., None], [1, 1, 3]) + if beta is None: + beta = 1 - alpha + colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma) + if not colored_img.dtype == np.uint8: + # Note when the dtype of `img` is not the default `np.uint8` + # (e.g. np.float32), the value in `colored_img` got from cv2 + # is not guaranteed to be in range [0, 255], so here clip + # is needed. + colored_img = np.clip(colored_img, 0, 255) + return colored_img + + +def imequalize(img): + """Equalize the image histogram. + + This function applies a non-linear mapping to the input image, + in order to create a uniform distribution of grayscale values + in the output image. + + Args: + img (ndarray): Image to be equalized. + + Returns: + ndarray: The equalized image. + """ + + def _scale_channel(im, c): + """Scale the data in the corresponding channel.""" + im = im[:, :, c] + # Compute the histogram of the image channel. + histo = np.histogram(im, 256, (0, 255))[0] + # For computing the step, filter out the nonzeros. + nonzero_histo = histo[histo > 0] + step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255 + if not step: + lut = np.array(range(256)) + else: + # Compute the cumulative sum, shifted by step // 2 + # and then normalized by step. + lut = (np.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = np.concatenate([[0], lut[:-1]], 0) + # handle potential integer overflow + lut[lut > 255] = 255 + # If step is zero, return the original image. + # Otherwise, index from lut. + return np.where(np.equal(step, 0), im, lut[im]) + + # Scales each channel independently and then stacks + # the result. + s1 = _scale_channel(img, 0) + s2 = _scale_channel(img, 1) + s3 = _scale_channel(img, 2) + equalized_img = np.stack([s1, s2, s3], axis=-1) + return equalized_img.astype(img.dtype) + + +def adjust_brightness(img, factor=1.): + """Adjust image brightness. + + This function controls the brightness of an image. An + enhancement factor of 0.0 gives a black image. + A factor of 1.0 gives the original image. This function + blends the source image and the degenerated black image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be brightened. + factor (float): A value controls the enhancement. + Factor 1.0 returns the original image, lower + factors mean less color (brightness, contrast, + etc), and higher values more. Default 1. + + Returns: + ndarray: The brightened image. + """ + degenerated = np.zeros_like(img) + # Note manually convert the dtype to np.float32, to + # achieve as close results as PIL.ImageEnhance.Brightness. + # Set beta=1-factor, and gamma=0 + brightened_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + brightened_img = np.clip(brightened_img, 0, 255) + return brightened_img.astype(img.dtype) + + +def adjust_contrast(img, factor=1.): + """Adjust image contrast. + + This function controls the contrast of an image. An + enhancement factor of 0.0 gives a solid grey + image. A factor of 1.0 gives the original image. It + blends the source image and the degenerated mean image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be contrasted. BGR order. + factor (float): Same as :func:`mmcv.adjust_brightness`. + + Returns: + ndarray: The contrasted image. + """ + gray_img = bgr2gray(img) + hist = np.histogram(gray_img, 256, (0, 255))[0] + mean = round(np.sum(gray_img) / np.sum(hist)) + degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype) + degenerated = gray2bgr(degenerated) + contrasted_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + contrasted_img = np.clip(contrasted_img, 0, 255) + return contrasted_img.astype(img.dtype) + + +def auto_contrast(img, cutoff=0): + """Auto adjust image contrast. + + This function maximize (normalize) image contrast by first removing cutoff + percent of the lightest and darkest pixels from the histogram and remapping + the image so that the darkest pixel becomes black (0), and the lightest + becomes white (255). + + Args: + img (ndarray): Image to be contrasted. BGR order. + cutoff (int | float | tuple): The cutoff percent of the lightest and + darkest pixels to be removed. If given as tuple, it shall be + (low, high). Otherwise, the single value will be used for both. + Defaults to 0. + + Returns: + ndarray: The contrasted image. + """ + + def _auto_contrast_channel(im, c, cutoff): + im = im[:, :, c] + # Compute the histogram of the image channel. + histo = np.histogram(im, 256, (0, 255))[0] + # Remove cut-off percent pixels from histo + histo_sum = np.cumsum(histo) + cut_low = histo_sum[-1] * cutoff[0] // 100 + cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100 + histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low + histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0) + + # Compute mapping + low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1] + # If all the values have been cut off, return the origin img + if low >= high: + return im + scale = 255.0 / (high - low) + offset = -low * scale + lut = np.array(range(256)) + lut = lut * scale + offset + lut = np.clip(lut, 0, 255) + return lut[im] + + if isinstance(cutoff, (int, float)): + cutoff = (cutoff, cutoff) + else: + assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \ + f'float or tuple, but got {type(cutoff)} instead.' + # Auto adjusts contrast for each channel independently and then stacks + # the result. + s1 = _auto_contrast_channel(img, 0, cutoff) + s2 = _auto_contrast_channel(img, 1, cutoff) + s3 = _auto_contrast_channel(img, 2, cutoff) + contrasted_img = np.stack([s1, s2, s3], axis=-1) + return contrasted_img.astype(img.dtype) + + +def adjust_sharpness(img, factor=1., kernel=None): + """Adjust image sharpness. + + This function controls the sharpness of an image. An + enhancement factor of 0.0 gives a blurred image. A + factor of 1.0 gives the original image. And a factor + of 2.0 gives a sharpened image. It blends the source + image and the degenerated mean image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be sharpened. BGR order. + factor (float): Same as :func:`mmcv.adjust_brightness`. + kernel (np.ndarray, optional): Filter kernel to be applied on the img + to obtain the degenerated img. Defaults to None. + + Note: + No value sanity check is enforced on the kernel set by users. So with + an inappropriate kernel, the ``adjust_sharpness`` may fail to perform + the function its name indicates but end up performing whatever + transform determined by the kernel. + + Returns: + ndarray: The sharpened image. + """ + + if kernel is None: + # adopted from PIL.ImageFilter.SMOOTH + kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13 + assert isinstance(kernel, np.ndarray), \ + f'kernel must be of type np.ndarray, but got {type(kernel)} instead.' + assert kernel.ndim == 2, \ + f'kernel must have a dimension of 2, but got {kernel.ndim} instead.' + + degenerated = cv2.filter2D(img, -1, kernel) + sharpened_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + sharpened_img = np.clip(sharpened_img, 0, 255) + return sharpened_img.astype(img.dtype) + + +def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True): + """AlexNet-style PCA jitter. + + This data augmentation is proposed in `ImageNet Classification with Deep + Convolutional Neural Networks + `_. + + Args: + img (ndarray): Image to be adjusted lighting. BGR order. + eigval (ndarray): the eigenvalue of the convariance matrix of pixel + values, respectively. + eigvec (ndarray): the eigenvector of the convariance matrix of pixel + values, respectively. + alphastd (float): The standard deviation for distribution of alpha. + Defaults to 0.1 + to_rgb (bool): Whether to convert img to rgb. + + Returns: + ndarray: The adjusted image. + """ + assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \ + f'eigval and eigvec should both be of type np.ndarray, got ' \ + f'{type(eigval)} and {type(eigvec)} instead.' + + assert eigval.ndim == 1 and eigvec.ndim == 2 + assert eigvec.shape == (3, eigval.shape[0]) + n_eigval = eigval.shape[0] + assert isinstance(alphastd, float), 'alphastd should be of type float, ' \ + f'got {type(alphastd)} instead.' + + img = img.copy().astype(np.float32) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + + alpha = np.random.normal(0, alphastd, n_eigval) + alter = eigvec \ + * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \ + * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval)) + alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape) + img_adjusted = img + alter + return img_adjusted + + +def lut_transform(img, lut_table): + """Transform array by look-up table. + + The function lut_transform fills the output array with values from the + look-up table. Indices of the entries are taken from the input array. + + Args: + img (ndarray): Image to be transformed. + lut_table (ndarray): look-up table of 256 elements; in case of + multi-channel input array, the table should either have a single + channel (in this case the same table is used for all channels) or + the same number of channels as in the input array. + + Returns: + ndarray: The transformed image. + """ + assert isinstance(img, np.ndarray) + assert 0 <= np.min(img) and np.max(img) <= 255 + assert isinstance(lut_table, np.ndarray) + assert lut_table.shape == (256, ) + + return cv2.LUT(np.array(img, dtype=np.uint8), lut_table) + + +def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + img (ndarray): Image to be processed. + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + + Returns: + ndarray: The processed image. + """ + assert isinstance(img, np.ndarray) + assert img.ndim == 2 + assert isinstance(clip_limit, (float, int)) + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + + clahe = cv2.createCLAHE(clip_limit, tile_grid_size) + return clahe.apply(np.array(img, dtype=np.uint8)) diff --git a/annotator/uniformer/mmcv/model_zoo/deprecated.json b/annotator/uniformer/mmcv/model_zoo/deprecated.json new file mode 100644 index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b --- /dev/null +++ b/annotator/uniformer/mmcv/model_zoo/deprecated.json @@ -0,0 +1,6 @@ +{ + "resnet50_caffe": "detectron/resnet50_caffe", + "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr", + "resnet101_caffe": "detectron/resnet101_caffe", + "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr" +} diff --git a/annotator/uniformer/mmcv/model_zoo/mmcls.json b/annotator/uniformer/mmcv/model_zoo/mmcls.json new file mode 100644 index 0000000000000000000000000000000000000000..bdb311d9fe6d9f317290feedc9e37236c6cf6e8f --- /dev/null +++ b/annotator/uniformer/mmcv/model_zoo/mmcls.json @@ -0,0 +1,31 @@ +{ + "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth", + "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth", + "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth", + "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth", + "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth", + "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth", + "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth", + "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth", + "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth", + "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth", + "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth", + "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth", + "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth", + "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_batch256_imagenet_20200708-1ad0ce94.pth", + "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_batch256_imagenet_20200708-9cb302ef.pth", + "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_batch256_imagenet_20200708-e79cb6a2.pth", + "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth", + "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth", + "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth", + "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth", + "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth", + "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth", + "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth", + "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth", + "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth", + "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth", + "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth", + "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth", + "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth" +} diff --git a/annotator/uniformer/mmcv/model_zoo/open_mmlab.json b/annotator/uniformer/mmcv/model_zoo/open_mmlab.json new file mode 100644 index 0000000000000000000000000000000000000000..8311db4feef92faa0841c697d75efbee8430c3a0 --- /dev/null +++ b/annotator/uniformer/mmcv/model_zoo/open_mmlab.json @@ -0,0 +1,50 @@ +{ + "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth", + "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth", + "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth", + "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth", + "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth", + "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth", + "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth", + "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth", + "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth", + "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth", + "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth", + "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth", + "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth", + "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth", + "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth", + "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth", + "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth", + "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth", + "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth", + "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth", + "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth", + "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth", + "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth", + "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth", + "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth", + "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth", + "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth", + "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth", + "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth", + "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth", + "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth", + "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth", + "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth", + "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth", + "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth", + "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth", + "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth", + "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth", + "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth", + "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth", + "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth", + "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth", + "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth", + "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth", + "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth", + "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth", + "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth", + "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth" +} diff --git a/annotator/uniformer/mmcv/ops/__init__.py b/annotator/uniformer/mmcv/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..999e090a458ee148ceca0649f1e3806a40e909bd --- /dev/null +++ b/annotator/uniformer/mmcv/ops/__init__.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .assign_score_withk import assign_score_withk +from .ball_query import ball_query +from .bbox import bbox_overlaps +from .border_align import BorderAlign, border_align +from .box_iou_rotated import box_iou_rotated +from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive +from .cc_attention import CrissCrossAttention +from .contour_expand import contour_expand +from .corner_pool import CornerPool +from .correlation import Correlation +from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d +from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack, + ModulatedDeformRoIPoolPack, deform_roi_pool) +from .deprecated_wrappers import Conv2d_deprecated as Conv2d +from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d +from .deprecated_wrappers import Linear_deprecated as Linear +from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d +from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, + sigmoid_focal_loss, softmax_focal_loss) +from .furthest_point_sample import (furthest_point_sample, + furthest_point_sample_with_dist) +from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu +from .gather_points import gather_points +from .group_points import GroupAll, QueryAndGroup, grouping_operation +from .info import (get_compiler_version, get_compiling_cuda_version, + get_onnxruntime_op_path) +from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev +from .knn import knn +from .masked_conv import MaskedConv2d, masked_conv2d +from .modulated_deform_conv import (ModulatedDeformConv2d, + ModulatedDeformConv2dPack, + modulated_deform_conv2d) +from .multi_scale_deform_attn import MultiScaleDeformableAttention +from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms +from .pixel_group import pixel_group +from .point_sample import (SimpleRoIAlign, point_sample, + rel_roi_point_to_rel_img_point) +from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu, + points_in_boxes_part) +from .points_sampler import PointsSampler +from .psa_mask import PSAMask +from .roi_align import RoIAlign, roi_align +from .roi_align_rotated import RoIAlignRotated, roi_align_rotated +from .roi_pool import RoIPool, roi_pool +from .roiaware_pool3d import RoIAwarePool3d +from .roipoint_pool3d import RoIPointPool3d +from .saconv import SAConv2d +from .scatter_points import DynamicScatter, dynamic_scatter +from .sync_bn import SyncBatchNorm +from .three_interpolate import three_interpolate +from .three_nn import three_nn +from .tin_shift import TINShift, tin_shift +from .upfirdn2d import upfirdn2d +from .voxelize import Voxelization, voxelization + +__all__ = [ + 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', + 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack', + 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack', + 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss', + 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss', + 'get_compiler_version', 'get_compiling_cuda_version', + 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d', + 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack', + 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match', + 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d', + 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask', + 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', + 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', + 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', + 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', + 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup', + 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn', + 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', + 'border_align', 'gather_points', 'furthest_point_sample', + 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', + 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization', + 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', + 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all' +] diff --git a/annotator/uniformer/mmcv/ops/assign_score_withk.py b/annotator/uniformer/mmcv/ops/assign_score_withk.py new file mode 100644 index 0000000000000000000000000000000000000000..4906adaa2cffd1b46912fbe7d4f87ef2f9fa0012 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/assign_score_withk.py @@ -0,0 +1,123 @@ +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['assign_score_withk_forward', 'assign_score_withk_backward']) + + +class AssignScoreWithK(Function): + r"""Perform weighted sum to generate output features according to scores. + Modified from `PAConv `_. + + This is a memory-efficient CUDA implementation of assign_scores operation, + which first transform all point features with weight bank, then assemble + neighbor features with ``knn_idx`` and perform weighted sum of ``scores``. + + See the `paper `_ appendix Sec. D for + more detailed descriptions. + + Note: + This implementation assumes using ``neighbor`` kernel input, which is + (point_features - center_features, point_features). + See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/ + pointnet2/paconv.py#L128 for more details. + """ + + @staticmethod + def forward(ctx, + scores, + point_features, + center_features, + knn_idx, + aggregate='sum'): + """ + Args: + scores (torch.Tensor): (B, npoint, K, M), predicted scores to + aggregate weight matrices in the weight bank. + ``npoint`` is the number of sampled centers. + ``K`` is the number of queried neighbors. + ``M`` is the number of weight matrices in the weight bank. + point_features (torch.Tensor): (B, N, M, out_dim) + Pre-computed point features to be aggregated. + center_features (torch.Tensor): (B, N, M, out_dim) + Pre-computed center features to be aggregated. + knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN. + We assume the first idx in each row is the idx of the center. + aggregate (str, optional): Aggregation method. + Can be 'sum', 'avg' or 'max'. Defaults: 'sum'. + + Returns: + torch.Tensor: (B, out_dim, npoint, K), the aggregated features. + """ + agg = {'sum': 0, 'avg': 1, 'max': 2} + + B, N, M, out_dim = point_features.size() + _, npoint, K, _ = scores.size() + + output = point_features.new_zeros((B, out_dim, npoint, K)) + ext_module.assign_score_withk_forward( + point_features.contiguous(), + center_features.contiguous(), + scores.contiguous(), + knn_idx.contiguous(), + output, + B=B, + N0=N, + N1=npoint, + M=M, + K=K, + O=out_dim, + aggregate=agg[aggregate]) + + ctx.save_for_backward(output, point_features, center_features, scores, + knn_idx) + ctx.agg = agg[aggregate] + + return output + + @staticmethod + def backward(ctx, grad_out): + """ + Args: + grad_out (torch.Tensor): (B, out_dim, npoint, K) + + Returns: + grad_scores (torch.Tensor): (B, npoint, K, M) + grad_point_features (torch.Tensor): (B, N, M, out_dim) + grad_center_features (torch.Tensor): (B, N, M, out_dim) + """ + _, point_features, center_features, scores, knn_idx = ctx.saved_tensors + + agg = ctx.agg + + B, N, M, out_dim = point_features.size() + _, npoint, K, _ = scores.size() + + grad_point_features = point_features.new_zeros(point_features.shape) + grad_center_features = center_features.new_zeros(center_features.shape) + grad_scores = scores.new_zeros(scores.shape) + + ext_module.assign_score_withk_backward( + grad_out.contiguous(), + point_features.contiguous(), + center_features.contiguous(), + scores.contiguous(), + knn_idx.contiguous(), + grad_point_features, + grad_center_features, + grad_scores, + B=B, + N0=N, + N1=npoint, + M=M, + K=K, + O=out_dim, + aggregate=agg) + + return grad_scores, grad_point_features, \ + grad_center_features, None, None + + +assign_score_withk = AssignScoreWithK.apply diff --git a/annotator/uniformer/mmcv/ops/ball_query.py b/annotator/uniformer/mmcv/ops/ball_query.py new file mode 100644 index 0000000000000000000000000000000000000000..d0466847c6e5c1239e359a0397568413ebc1504a --- /dev/null +++ b/annotator/uniformer/mmcv/ops/ball_query.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['ball_query_forward']) + + +class BallQuery(Function): + """Find nearby points in spherical space.""" + + @staticmethod + def forward(ctx, min_radius: float, max_radius: float, sample_num: int, + xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor: + """ + Args: + min_radius (float): minimum radius of the balls. + max_radius (float): maximum radius of the balls. + sample_num (int): maximum number of features in the balls. + xyz (Tensor): (B, N, 3) xyz coordinates of the features. + center_xyz (Tensor): (B, npoint, 3) centers of the ball query. + + Returns: + Tensor: (B, npoint, nsample) tensor with the indices of + the features that form the query balls. + """ + assert center_xyz.is_contiguous() + assert xyz.is_contiguous() + assert min_radius < max_radius + + B, N, _ = xyz.size() + npoint = center_xyz.size(1) + idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int) + + ext_module.ball_query_forward( + center_xyz, + xyz, + idx, + b=B, + n=N, + m=npoint, + min_radius=min_radius, + max_radius=max_radius, + nsample=sample_num) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None, None, None + + +ball_query = BallQuery.apply diff --git a/annotator/uniformer/mmcv/ops/bbox.py b/annotator/uniformer/mmcv/ops/bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..0c4d58b6c91f652933974f519acd3403a833e906 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/bbox.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps']) + + +def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0): + """Calculate overlap between two set of bboxes. + + If ``aligned`` is ``False``, then calculate the ious between each bbox + of bboxes1 and bboxes2, otherwise the ious between each aligned pair of + bboxes1 and bboxes2. + + Args: + bboxes1 (Tensor): shape (m, 4) in format or empty. + bboxes2 (Tensor): shape (n, 4) in format or empty. + If aligned is ``True``, then m and n must be equal. + mode (str): "iou" (intersection over union) or iof (intersection over + foreground). + + Returns: + ious(Tensor): shape (m, n) if aligned == False else shape (m, 1) + + Example: + >>> bboxes1 = torch.FloatTensor([ + >>> [0, 0, 10, 10], + >>> [10, 10, 20, 20], + >>> [32, 32, 38, 42], + >>> ]) + >>> bboxes2 = torch.FloatTensor([ + >>> [0, 0, 10, 20], + >>> [0, 10, 10, 19], + >>> [10, 10, 20, 20], + >>> ]) + >>> bbox_overlaps(bboxes1, bboxes2) + tensor([[0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 1.0000], + [0.0000, 0.0000, 0.0000]]) + + Example: + >>> empty = torch.FloatTensor([]) + >>> nonempty = torch.FloatTensor([ + >>> [0, 0, 10, 9], + >>> ]) + >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) + >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) + >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) + """ + + mode_dict = {'iou': 0, 'iof': 1} + assert mode in mode_dict.keys() + mode_flag = mode_dict[mode] + # Either the boxes are empty or the length of boxes' last dimension is 4 + assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0) + assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0) + assert offset == 1 or offset == 0 + + rows = bboxes1.size(0) + cols = bboxes2.size(0) + if aligned: + assert rows == cols + + if rows * cols == 0: + return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols) + + if aligned: + ious = bboxes1.new_zeros(rows) + else: + ious = bboxes1.new_zeros((rows, cols)) + ext_module.bbox_overlaps( + bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset) + return ious diff --git a/annotator/uniformer/mmcv/ops/border_align.py b/annotator/uniformer/mmcv/ops/border_align.py new file mode 100644 index 0000000000000000000000000000000000000000..ff305be328e9b0a15e1bbb5e6b41beb940f55c81 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/border_align.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modified from +# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['border_align_forward', 'border_align_backward']) + + +class BorderAlignFunction(Function): + + @staticmethod + def symbolic(g, input, boxes, pool_size): + return g.op( + 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size) + + @staticmethod + def forward(ctx, input, boxes, pool_size): + ctx.pool_size = pool_size + ctx.input_shape = input.size() + + assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]' + assert boxes.size(2) == 4, \ + 'the last dimension of boxes must be (x1, y1, x2, y2)' + assert input.size(1) % 4 == 0, \ + 'the channel for input feature must be divisible by factor 4' + + # [B, C//4, H*W, 4] + output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4) + output = input.new_zeros(output_shape) + # `argmax_idx` only used for backward + argmax_idx = input.new_zeros(output_shape).to(torch.int) + + ext_module.border_align_forward( + input, boxes, output, argmax_idx, pool_size=ctx.pool_size) + + ctx.save_for_backward(boxes, argmax_idx) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + boxes, argmax_idx = ctx.saved_tensors + grad_input = grad_output.new_zeros(ctx.input_shape) + # complex head architecture may cause grad_output uncontiguous + grad_output = grad_output.contiguous() + ext_module.border_align_backward( + grad_output, + boxes, + argmax_idx, + grad_input, + pool_size=ctx.pool_size) + return grad_input, None, None + + +border_align = BorderAlignFunction.apply + + +class BorderAlign(nn.Module): + r"""Border align pooling layer. + + Applies border_align over the input feature based on predicted bboxes. + The details were described in the paper + `BorderDet: Border Feature for Dense Object Detection + `_. + + For each border line (e.g. top, left, bottom or right) of each box, + border_align does the following: + 1. uniformly samples `pool_size`+1 positions on this line, involving \ + the start and end points. + 2. the corresponding features on these points are computed by \ + bilinear interpolation. + 3. max pooling over all the `pool_size`+1 positions are used for \ + computing pooled feature. + + Args: + pool_size (int): number of positions sampled over the boxes' borders + (e.g. top, bottom, left, right). + + """ + + def __init__(self, pool_size): + super(BorderAlign, self).__init__() + self.pool_size = pool_size + + def forward(self, input, boxes): + """ + Args: + input: Features with shape [N,4C,H,W]. Channels ranged in [0,C), + [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom, + right features respectively. + boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2). + + Returns: + Tensor: Pooled features with shape [N,C,H*W,4]. The order is + (top,left,bottom,right) for the last dimension. + """ + return border_align(input, boxes, self.pool_size) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(pool_size={self.pool_size})' + return s diff --git a/annotator/uniformer/mmcv/ops/box_iou_rotated.py b/annotator/uniformer/mmcv/ops/box_iou_rotated.py new file mode 100644 index 0000000000000000000000000000000000000000..2d78015e9c2a9e7a52859b4e18f84a9aa63481a0 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/box_iou_rotated.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated']) + + +def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False): + """Return intersection-over-union (Jaccard index) of boxes. + + Both sets of boxes are expected to be in + (x_center, y_center, width, height, angle) format. + + If ``aligned`` is ``False``, then calculate the ious between each bbox + of bboxes1 and bboxes2, otherwise the ious between each aligned pair of + bboxes1 and bboxes2. + + Arguments: + boxes1 (Tensor): rotated bboxes 1. \ + It has shape (N, 5), indicating (x, y, w, h, theta) for each row. + Note that theta is in radian. + boxes2 (Tensor): rotated bboxes 2. \ + It has shape (M, 5), indicating (x, y, w, h, theta) for each row. + Note that theta is in radian. + mode (str): "iou" (intersection over union) or iof (intersection over + foreground). + + Returns: + ious(Tensor): shape (N, M) if aligned == False else shape (N,) + """ + assert mode in ['iou', 'iof'] + mode_dict = {'iou': 0, 'iof': 1} + mode_flag = mode_dict[mode] + rows = bboxes1.size(0) + cols = bboxes2.size(0) + if aligned: + ious = bboxes1.new_zeros(rows) + else: + ious = bboxes1.new_zeros((rows * cols)) + bboxes1 = bboxes1.contiguous() + bboxes2 = bboxes2.contiguous() + ext_module.box_iou_rotated( + bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned) + if not aligned: + ious = ious.view(rows, cols) + return ious diff --git a/annotator/uniformer/mmcv/ops/carafe.py b/annotator/uniformer/mmcv/ops/carafe.py new file mode 100644 index 0000000000000000000000000000000000000000..5154cb3abfccfbbe0a1b2daa67018dbf80aaf6d2 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/carafe.py @@ -0,0 +1,287 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.nn.modules.module import Module + +from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'carafe_naive_forward', 'carafe_naive_backward', 'carafe_forward', + 'carafe_backward' +]) + + +class CARAFENaiveFunction(Function): + + @staticmethod + def symbolic(g, features, masks, kernel_size, group_size, scale_factor): + return g.op( + 'mmcv::MMCVCARAFENaive', + features, + masks, + kernel_size_i=kernel_size, + group_size_i=group_size, + scale_factor_f=scale_factor) + + @staticmethod + def forward(ctx, features, masks, kernel_size, group_size, scale_factor): + assert scale_factor >= 1 + assert masks.size(1) == kernel_size * kernel_size * group_size + assert masks.size(-1) == features.size(-1) * scale_factor + assert masks.size(-2) == features.size(-2) * scale_factor + assert features.size(1) % group_size == 0 + assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1 + ctx.kernel_size = kernel_size + ctx.group_size = group_size + ctx.scale_factor = scale_factor + ctx.feature_size = features.size() + ctx.mask_size = masks.size() + + n, c, h, w = features.size() + output = features.new_zeros((n, c, h * scale_factor, w * scale_factor)) + ext_module.carafe_naive_forward( + features, + masks, + output, + kernel_size=kernel_size, + group_size=group_size, + scale_factor=scale_factor) + + if features.requires_grad or masks.requires_grad: + ctx.save_for_backward(features, masks) + return output + + @staticmethod + def backward(ctx, grad_output): + assert grad_output.is_cuda + + features, masks = ctx.saved_tensors + kernel_size = ctx.kernel_size + group_size = ctx.group_size + scale_factor = ctx.scale_factor + + grad_input = torch.zeros_like(features) + grad_masks = torch.zeros_like(masks) + ext_module.carafe_naive_backward( + grad_output.contiguous(), + features, + masks, + grad_input, + grad_masks, + kernel_size=kernel_size, + group_size=group_size, + scale_factor=scale_factor) + + return grad_input, grad_masks, None, None, None + + +carafe_naive = CARAFENaiveFunction.apply + + +class CARAFENaive(Module): + + def __init__(self, kernel_size, group_size, scale_factor): + super(CARAFENaive, self).__init__() + + assert isinstance(kernel_size, int) and isinstance( + group_size, int) and isinstance(scale_factor, int) + self.kernel_size = kernel_size + self.group_size = group_size + self.scale_factor = scale_factor + + def forward(self, features, masks): + return carafe_naive(features, masks, self.kernel_size, self.group_size, + self.scale_factor) + + +class CARAFEFunction(Function): + + @staticmethod + def symbolic(g, features, masks, kernel_size, group_size, scale_factor): + return g.op( + 'mmcv::MMCVCARAFE', + features, + masks, + kernel_size_i=kernel_size, + group_size_i=group_size, + scale_factor_f=scale_factor) + + @staticmethod + def forward(ctx, features, masks, kernel_size, group_size, scale_factor): + assert scale_factor >= 1 + assert masks.size(1) == kernel_size * kernel_size * group_size + assert masks.size(-1) == features.size(-1) * scale_factor + assert masks.size(-2) == features.size(-2) * scale_factor + assert features.size(1) % group_size == 0 + assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1 + ctx.kernel_size = kernel_size + ctx.group_size = group_size + ctx.scale_factor = scale_factor + ctx.feature_size = features.size() + ctx.mask_size = masks.size() + + n, c, h, w = features.size() + output = features.new_zeros((n, c, h * scale_factor, w * scale_factor)) + routput = features.new_zeros(output.size(), requires_grad=False) + rfeatures = features.new_zeros(features.size(), requires_grad=False) + rmasks = masks.new_zeros(masks.size(), requires_grad=False) + ext_module.carafe_forward( + features, + masks, + rfeatures, + routput, + rmasks, + output, + kernel_size=kernel_size, + group_size=group_size, + scale_factor=scale_factor) + + if features.requires_grad or masks.requires_grad: + ctx.save_for_backward(features, masks, rfeatures) + return output + + @staticmethod + def backward(ctx, grad_output): + assert grad_output.is_cuda + + features, masks, rfeatures = ctx.saved_tensors + kernel_size = ctx.kernel_size + group_size = ctx.group_size + scale_factor = ctx.scale_factor + + rgrad_output = torch.zeros_like(grad_output, requires_grad=False) + rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False) + rgrad_input = torch.zeros_like(features, requires_grad=False) + rgrad_masks = torch.zeros_like(masks, requires_grad=False) + grad_input = torch.zeros_like(features, requires_grad=False) + grad_masks = torch.zeros_like(masks, requires_grad=False) + ext_module.carafe_backward( + grad_output.contiguous(), + rfeatures, + masks, + rgrad_output, + rgrad_input_hs, + rgrad_input, + rgrad_masks, + grad_input, + grad_masks, + kernel_size=kernel_size, + group_size=group_size, + scale_factor=scale_factor) + return grad_input, grad_masks, None, None, None + + +carafe = CARAFEFunction.apply + + +class CARAFE(Module): + """ CARAFE: Content-Aware ReAssembly of FEatures + + Please refer to https://arxiv.org/abs/1905.02188 for more details. + + Args: + kernel_size (int): reassemble kernel size + group_size (int): reassemble group size + scale_factor (int): upsample ratio + + Returns: + upsampled feature map + """ + + def __init__(self, kernel_size, group_size, scale_factor): + super(CARAFE, self).__init__() + + assert isinstance(kernel_size, int) and isinstance( + group_size, int) and isinstance(scale_factor, int) + self.kernel_size = kernel_size + self.group_size = group_size + self.scale_factor = scale_factor + + def forward(self, features, masks): + return carafe(features, masks, self.kernel_size, self.group_size, + self.scale_factor) + + +@UPSAMPLE_LAYERS.register_module(name='carafe') +class CARAFEPack(nn.Module): + """A unified package of CARAFE upsampler that contains: 1) channel + compressor 2) content encoder 3) CARAFE op. + + Official implementation of ICCV 2019 paper + CARAFE: Content-Aware ReAssembly of FEatures + Please refer to https://arxiv.org/abs/1905.02188 for more details. + + Args: + channels (int): input feature channels + scale_factor (int): upsample ratio + up_kernel (int): kernel size of CARAFE op + up_group (int): group size of CARAFE op + encoder_kernel (int): kernel size of content encoder + encoder_dilation (int): dilation of content encoder + compressed_channels (int): output channels of channels compressor + + Returns: + upsampled feature map + """ + + def __init__(self, + channels, + scale_factor, + up_kernel=5, + up_group=1, + encoder_kernel=3, + encoder_dilation=1, + compressed_channels=64): + super(CARAFEPack, self).__init__() + self.channels = channels + self.scale_factor = scale_factor + self.up_kernel = up_kernel + self.up_group = up_group + self.encoder_kernel = encoder_kernel + self.encoder_dilation = encoder_dilation + self.compressed_channels = compressed_channels + self.channel_compressor = nn.Conv2d(channels, self.compressed_channels, + 1) + self.content_encoder = nn.Conv2d( + self.compressed_channels, + self.up_kernel * self.up_kernel * self.up_group * + self.scale_factor * self.scale_factor, + self.encoder_kernel, + padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2), + dilation=self.encoder_dilation, + groups=1) + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + normal_init(self.content_encoder, std=0.001) + + def kernel_normalizer(self, mask): + mask = F.pixel_shuffle(mask, self.scale_factor) + n, mask_c, h, w = mask.size() + # use float division explicitly, + # to void inconsistency while exporting to onnx + mask_channel = int(mask_c / float(self.up_kernel**2)) + mask = mask.view(n, mask_channel, -1, h, w) + + mask = F.softmax(mask, dim=2, dtype=mask.dtype) + mask = mask.view(n, mask_c, h, w).contiguous() + + return mask + + def feature_reassemble(self, x, mask): + x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor) + return x + + def forward(self, x): + compressed_x = self.channel_compressor(x) + mask = self.content_encoder(compressed_x) + mask = self.kernel_normalizer(mask) + + x = self.feature_reassemble(x, mask) + return x diff --git a/annotator/uniformer/mmcv/ops/cc_attention.py b/annotator/uniformer/mmcv/ops/cc_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9207aa95e6730bd9b3362dee612059a5f0ce1c5e --- /dev/null +++ b/annotator/uniformer/mmcv/ops/cc_attention.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from annotator.uniformer.mmcv.cnn import PLUGIN_LAYERS, Scale + + +def NEG_INF_DIAG(n, device): + """Returns a diagonal matrix of size [n, n]. + + The diagonal are all "-inf". This is for avoiding calculating the + overlapped element in the Criss-Cross twice. + """ + return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0) + + +@PLUGIN_LAYERS.register_module() +class CrissCrossAttention(nn.Module): + """Criss-Cross Attention Module. + + .. note:: + Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch + to a pure PyTorch and equivalent implementation. For more + details, please refer to https://github.com/open-mmlab/mmcv/pull/1201. + + Speed comparison for one forward pass + + - Input size: [2,512,97,97] + - Device: 1 NVIDIA GeForce RTX 2080 Ti + + +-----------------------+---------------+------------+---------------+ + | |PyTorch version|CUDA version|Relative speed | + +=======================+===============+============+===============+ + |with torch.no_grad() |0.00554402 s |0.0299619 s |5.4x | + +-----------------------+---------------+------------+---------------+ + |no with torch.no_grad()|0.00562803 s |0.0301349 s |5.4x | + +-----------------------+---------------+------------+---------------+ + + Args: + in_channels (int): Channels of the input feature map. + """ + + def __init__(self, in_channels): + super().__init__() + self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1) + self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1) + self.value_conv = nn.Conv2d(in_channels, in_channels, 1) + self.gamma = Scale(0.) + self.in_channels = in_channels + + def forward(self, x): + """forward function of Criss-Cross Attention. + + Args: + x (Tensor): Input feature. \ + shape (batch_size, in_channels, height, width) + Returns: + Tensor: Output of the layer, with shape of \ + (batch_size, in_channels, height, width) + """ + B, C, H, W = x.size() + query = self.query_conv(x) + key = self.key_conv(x) + value = self.value_conv(x) + energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG( + H, query.device) + energy_H = energy_H.transpose(1, 2) + energy_W = torch.einsum('bchw,bchj->bhwj', query, key) + attn = F.softmax( + torch.cat([energy_H, energy_W], dim=-1), dim=-1) # [B,H,W,(H+W)] + out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H]) + out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:]) + + out = self.gamma(out) + x + out = out.contiguous() + + return out + + def __repr__(self): + s = self.__class__.__name__ + s += f'(in_channels={self.in_channels})' + return s diff --git a/annotator/uniformer/mmcv/ops/contour_expand.py b/annotator/uniformer/mmcv/ops/contour_expand.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1111e1768b5f27e118bf7dbc0d9c70a7afd6d7 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/contour_expand.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['contour_expand']) + + +def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area, + kernel_num): + """Expand kernel contours so that foreground pixels are assigned into + instances. + + Arguments: + kernel_mask (np.array or Tensor): The instance kernel mask with + size hxw. + internal_kernel_label (np.array or Tensor): The instance internal + kernel label with size hxw. + min_kernel_area (int): The minimum kernel area. + kernel_num (int): The instance kernel number. + + Returns: + label (list): The instance index map with size hxw. + """ + assert isinstance(kernel_mask, (torch.Tensor, np.ndarray)) + assert isinstance(internal_kernel_label, (torch.Tensor, np.ndarray)) + assert isinstance(min_kernel_area, int) + assert isinstance(kernel_num, int) + + if isinstance(kernel_mask, np.ndarray): + kernel_mask = torch.from_numpy(kernel_mask) + if isinstance(internal_kernel_label, np.ndarray): + internal_kernel_label = torch.from_numpy(internal_kernel_label) + + if torch.__version__ == 'parrots': + if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0: + label = [] + else: + label = ext_module.contour_expand( + kernel_mask, + internal_kernel_label, + min_kernel_area=min_kernel_area, + kernel_num=kernel_num) + label = label.tolist() + else: + label = ext_module.contour_expand(kernel_mask, internal_kernel_label, + min_kernel_area, kernel_num) + return label diff --git a/annotator/uniformer/mmcv/ops/corner_pool.py b/annotator/uniformer/mmcv/ops/corner_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..a33d798b43d405e4c86bee4cd6389be21ca9c637 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/corner_pool.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward', + 'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward', + 'right_pool_forward', 'right_pool_backward' +]) + +_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} + + +class TopPoolFunction(Function): + + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top'])) + return output + + @staticmethod + def forward(ctx, input): + output = ext_module.top_pool_forward(input) + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + output = ext_module.top_pool_backward(input, grad_output) + return output + + +class BottomPoolFunction(Function): + + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom'])) + return output + + @staticmethod + def forward(ctx, input): + output = ext_module.bottom_pool_forward(input) + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + output = ext_module.bottom_pool_backward(input, grad_output) + return output + + +class LeftPoolFunction(Function): + + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left'])) + return output + + @staticmethod + def forward(ctx, input): + output = ext_module.left_pool_forward(input) + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + output = ext_module.left_pool_backward(input, grad_output) + return output + + +class RightPoolFunction(Function): + + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right'])) + return output + + @staticmethod + def forward(ctx, input): + output = ext_module.right_pool_forward(input) + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + output = ext_module.right_pool_backward(input, grad_output) + return output + + +class CornerPool(nn.Module): + """Corner Pooling. + + Corner Pooling is a new type of pooling layer that helps a + convolutional network better localize corners of bounding boxes. + + Please refer to https://arxiv.org/abs/1808.01244 for more details. + Code is modified from https://github.com/princeton-vl/CornerNet-Lite. + + Args: + mode(str): Pooling orientation for the pooling layer + + - 'bottom': Bottom Pooling + - 'left': Left Pooling + - 'right': Right Pooling + - 'top': Top Pooling + + Returns: + Feature map after pooling. + """ + + pool_functions = { + 'bottom': BottomPoolFunction, + 'left': LeftPoolFunction, + 'right': RightPoolFunction, + 'top': TopPoolFunction, + } + + cummax_dim_flip = { + 'bottom': (2, False), + 'left': (3, True), + 'right': (3, False), + 'top': (2, True), + } + + def __init__(self, mode): + super(CornerPool, self).__init__() + assert mode in self.pool_functions + self.mode = mode + self.corner_pool = self.pool_functions[mode] + + def forward(self, x): + if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0': + if torch.onnx.is_in_onnx_export(): + assert torch.__version__ >= '1.7.0', \ + 'When `cummax` serves as an intermediate component whose '\ + 'outputs is used as inputs for another modules, it\'s '\ + 'expected that pytorch version must be >= 1.7.0, '\ + 'otherwise Error appears like: `RuntimeError: tuple '\ + 'appears in op that does not forward tuples, unsupported '\ + 'kind: prim::PythonOp`.' + + dim, flip = self.cummax_dim_flip[self.mode] + if flip: + x = x.flip(dim) + pool_tensor, _ = torch.cummax(x, dim=dim) + if flip: + pool_tensor = pool_tensor.flip(dim) + return pool_tensor + else: + return self.corner_pool.apply(x) diff --git a/annotator/uniformer/mmcv/ops/correlation.py b/annotator/uniformer/mmcv/ops/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0b79c301b29915dfaf4d2b1846c59be73127d3 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/correlation.py @@ -0,0 +1,196 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor, nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['correlation_forward', 'correlation_backward']) + + +class CorrelationFunction(Function): + + @staticmethod + def forward(ctx, + input1, + input2, + kernel_size=1, + max_displacement=1, + stride=1, + padding=1, + dilation=1, + dilation_patch=1): + + ctx.save_for_backward(input1, input2) + + kH, kW = ctx.kernel_size = _pair(kernel_size) + patch_size = max_displacement * 2 + 1 + ctx.patch_size = patch_size + dH, dW = ctx.stride = _pair(stride) + padH, padW = ctx.padding = _pair(padding) + dilationH, dilationW = ctx.dilation = _pair(dilation) + dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair( + dilation_patch) + + output_size = CorrelationFunction._output_size(ctx, input1) + + output = input1.new_zeros(output_size) + + ext_module.correlation_forward( + input1, + input2, + output, + kH=kH, + kW=kW, + patchH=patch_size, + patchW=patch_size, + padH=padH, + padW=padW, + dilationH=dilationH, + dilationW=dilationW, + dilation_patchH=dilation_patchH, + dilation_patchW=dilation_patchW, + dH=dH, + dW=dW) + + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input1, input2 = ctx.saved_tensors + + kH, kW = ctx.kernel_size + patch_size = ctx.patch_size + padH, padW = ctx.padding + dilationH, dilationW = ctx.dilation + dilation_patchH, dilation_patchW = ctx.dilation_patch + dH, dW = ctx.stride + grad_input1 = torch.zeros_like(input1) + grad_input2 = torch.zeros_like(input2) + + ext_module.correlation_backward( + grad_output, + input1, + input2, + grad_input1, + grad_input2, + kH=kH, + kW=kW, + patchH=patch_size, + patchW=patch_size, + padH=padH, + padW=padW, + dilationH=dilationH, + dilationW=dilationW, + dilation_patchH=dilation_patchH, + dilation_patchW=dilation_patchW, + dH=dH, + dW=dW) + return grad_input1, grad_input2, None, None, None, None, None, None + + @staticmethod + def _output_size(ctx, input1): + iH, iW = input1.size(2), input1.size(3) + batch_size = input1.size(0) + kH, kW = ctx.kernel_size + patch_size = ctx.patch_size + dH, dW = ctx.stride + padH, padW = ctx.padding + dilationH, dilationW = ctx.dilation + dilatedKH = (kH - 1) * dilationH + 1 + dilatedKW = (kW - 1) * dilationW + 1 + + oH = int((iH + 2 * padH - dilatedKH) / dH + 1) + oW = int((iW + 2 * padW - dilatedKW) / dW + 1) + + output_size = (batch_size, patch_size, patch_size, oH, oW) + return output_size + + +class Correlation(nn.Module): + r"""Correlation operator + + This correlation operator works for optical flow correlation computation. + + There are two batched tensors with shape :math:`(N, C, H, W)`, + and the correlation output's shape is :math:`(N, max\_displacement \times + 2 + 1, max\_displacement * 2 + 1, H_{out}, W_{out})` + + where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times padding - + dilation \times (kernel\_size - 1) - 1} + {stride} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times padding - dilation + \times (kernel\_size - 1) - 1} + {stride} + 1\right\rfloor + + the correlation item :math:`(N_i, dy, dx)` is formed by taking the sliding + window convolution between input1 and shifted input2, + + .. math:: + Corr(N_i, dx, dy) = + \sum_{c=0}^{C-1} + input1(N_i, c) \star + \mathcal{S}(input2(N_i, c), dy, dx) + + where :math:`\star` is the valid 2d sliding window convolution operator, + and :math:`\mathcal{S}` means shifting the input features (auto-complete + zero marginal), and :math:`dx, dy` are shifting distance, :math:`dx, dy \in + [-max\_displacement \times dilation\_patch, max\_displacement \times + dilation\_patch]`. + + Args: + kernel_size (int): The size of sliding window i.e. local neighborhood + representing the center points and involved in correlation + computation. Defaults to 1. + max_displacement (int): The radius for computing correlation volume, + but the actual working space can be dilated by dilation_patch. + Defaults to 1. + stride (int): The stride of the sliding blocks in the input spatial + dimensions. Defaults to 1. + padding (int): Zero padding added to all four sides of the input1. + Defaults to 0. + dilation (int): The spacing of local neighborhood that will involved + in correlation. Defaults to 1. + dilation_patch (int): The spacing between position need to compute + correlation. Defaults to 1. + """ + + def __init__(self, + kernel_size: int = 1, + max_displacement: int = 1, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + dilation_patch: int = 1) -> None: + super().__init__() + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride = stride + self.padding = padding + self.dilation = dilation + self.dilation_patch = dilation_patch + + def forward(self, input1: Tensor, input2: Tensor) -> Tensor: + return CorrelationFunction.apply(input1, input2, self.kernel_size, + self.max_displacement, self.stride, + self.padding, self.dilation, + self.dilation_patch) + + def __repr__(self) -> str: + s = self.__class__.__name__ + s += f'(kernel_size={self.kernel_size}, ' + s += f'max_displacement={self.max_displacement}, ' + s += f'stride={self.stride}, ' + s += f'padding={self.padding}, ' + s += f'dilation={self.dilation}, ' + s += f'dilation_patch={self.dilation_patch})' + return s diff --git a/annotator/uniformer/mmcv/ops/deform_conv.py b/annotator/uniformer/mmcv/ops/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f8c75ee774823eea334e3b3732af6a18f55038 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/deform_conv.py @@ -0,0 +1,405 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair, _single + +from annotator.uniformer.mmcv.utils import deprecated_api_warning +from ..cnn import CONV_LAYERS +from ..utils import ext_loader, print_log + +ext_module = ext_loader.load_ext('_ext', [ + 'deform_conv_forward', 'deform_conv_backward_input', + 'deform_conv_backward_parameters' +]) + + +class DeformConv2dFunction(Function): + + @staticmethod + def symbolic(g, + input, + offset, + weight, + stride, + padding, + dilation, + groups, + deform_groups, + bias=False, + im2col_step=32): + return g.op( + 'mmcv::MMCVDeformConv2d', + input, + offset, + weight, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups, + bias_i=bias, + im2col_step_i=im2col_step) + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1, + bias=False, + im2col_step=32): + if input is not None and input.dim() != 4: + raise ValueError( + f'Expected 4D tensor as input, got {input.dim()}D tensor \ + instead.') + assert bias is False, 'Only support bias is False.' + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deform_groups = deform_groups + ctx.im2col_step = im2col_step + + # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; + # amp won't cast the type of model (float32), but "offset" is cast + # to float16 by nn.Conv2d automatically, leading to the type + # mismatch with input (when it is float32) or weight. + # The flag for whether to use fp16 or amp is the type of "offset", + # we cast weight and input to temporarily support fp16 and amp + # whatever the pytorch version is. + input = input.type_as(offset) + weight = weight.type_as(input) + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty( + DeformConv2dFunction._output_size(ctx, input, weight)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + cur_im2col_step = min(ctx.im2col_step, input.size(0)) + assert (input.size(0) % + cur_im2col_step) == 0, 'im2col step must divide batchsize' + ext_module.deform_conv_forward( + input, + weight, + offset, + output, + ctx.bufs_[0], + ctx.bufs_[1], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + im2col_step=cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + cur_im2col_step = min(ctx.im2col_step, input.size(0)) + assert (input.size(0) % cur_im2col_step + ) == 0, 'batch size must be divisible by im2col_step' + + grad_output = grad_output.contiguous() + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + ext_module.deform_conv_backward_input( + input, + offset, + grad_output, + grad_input, + grad_offset, + weight, + ctx.bufs_[0], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + im2col_step=cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + ext_module.deform_conv_backward_parameters( + input, + offset, + grad_output, + grad_weight, + ctx.bufs_[0], + ctx.bufs_[1], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + scale=1, + im2col_step=cur_im2col_step) + + return grad_input, grad_offset, grad_weight, \ + None, None, None, None, None, None, None + + @staticmethod + def _output_size(ctx, input, weight): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = ctx.padding[d] + kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = ctx.stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + 'convolution input is too small (output would be ' + + 'x'.join(map(str, output_size)) + ')') + return output_size + + +deform_conv2d = DeformConv2dFunction.apply + + +class DeformConv2d(nn.Module): + r"""Deformable 2D convolution. + + Applies a deformable 2D convolution over an input signal composed of + several input planes. DeformConv2d was described in the paper + `Deformable Convolutional Networks + `_ + + Note: + The argument ``im2col_step`` was added in version 1.3.17, which means + number of samples processed by the ``im2col_cuda_kernel`` per call. + It enables users to define ``batch_size`` and ``im2col_step`` more + flexibly and solved `issue mmcv#1440 + `_. + + Args: + in_channels (int): Number of channels in the input image. + out_channels (int): Number of channels produced by the convolution. + kernel_size(int, tuple): Size of the convolving kernel. + stride(int, tuple): Stride of the convolution. Default: 1. + padding (int or tuple): Zero-padding added to both sides of the input. + Default: 0. + dilation (int or tuple): Spacing between kernel elements. Default: 1. + groups (int): Number of blocked connections from input. + channels to output channels. Default: 1. + deform_groups (int): Number of deformable group partitions. + bias (bool): If True, adds a learnable bias to the output. + Default: False. + im2col_step (int): Number of samples processed by im2col_cuda_kernel + per call. It will work when ``batch_size`` > ``im2col_step``, but + ``batch_size`` must be divisible by ``im2col_step``. Default: 32. + `New in version 1.3.17.` + """ + + @deprecated_api_warning({'deformable_groups': 'deform_groups'}, + cls_name='DeformConv2d') + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + deform_groups: int = 1, + bias: bool = False, + im2col_step: int = 32) -> None: + super(DeformConv2d, self).__init__() + + assert not bias, \ + f'bias={bias} is not supported in DeformConv2d.' + assert in_channels % groups == 0, \ + f'in_channels {in_channels} cannot be divisible by groups {groups}' + assert out_channels % groups == 0, \ + f'out_channels {out_channels} cannot be divisible by groups \ + {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deform_groups = deform_groups + self.im2col_step = im2col_step + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + # only weight, no bias + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, + *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + # switch the initialization of `self.weight` to the standard kaiming + # method described in `Delving deep into rectifiers: Surpassing + # human-level performance on ImageNet classification` - He, K. et al. + # (2015), using a uniform distribution + nn.init.kaiming_uniform_(self.weight, nonlinearity='relu') + + def forward(self, x: Tensor, offset: Tensor) -> Tensor: + """Deformable Convolutional forward function. + + Args: + x (Tensor): Input feature, shape (B, C_in, H_in, W_in) + offset (Tensor): Offset for deformable convolution, shape + (B, deform_groups*kernel_size[0]*kernel_size[1]*2, + H_out, W_out), H_out, W_out are equal to the output's. + + An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`. + The spatial arrangement is like: + + .. code:: text + + (x0, y0) (x1, y1) (x2, y2) + (x3, y3) (x4, y4) (x5, y5) + (x6, y6) (x7, y7) (x8, y8) + + Returns: + Tensor: Output of the layer. + """ + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) < + self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0) + offset = offset.contiguous() + out = deform_conv2d(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deform_groups, + False, self.im2col_step) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - + pad_w].contiguous() + return out + + def __repr__(self): + s = self.__class__.__name__ + s += f'(in_channels={self.in_channels},\n' + s += f'out_channels={self.out_channels},\n' + s += f'kernel_size={self.kernel_size},\n' + s += f'stride={self.stride},\n' + s += f'padding={self.padding},\n' + s += f'dilation={self.dilation},\n' + s += f'groups={self.groups},\n' + s += f'deform_groups={self.deform_groups},\n' + # bias is not supported in DeformConv2d. + s += 'bias=False)' + return s + + +@CONV_LAYERS.register_module('DCN') +class DeformConv2dPack(DeformConv2d): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`. + The spatial arrangement is like: + + .. code:: text + + (x0, y0) (x1, y1) (x2, y2) + (x3, y3) (x4, y4) (x5, y5) + (x6, y6) (x7, y7) (x8, y8) + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConv2dPack, self).__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv2d(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deform_groups, + False, self.im2col_step) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, DeformConvPack loads previous benchmark models. + if (prefix + 'conv_offset.weight' not in state_dict + and prefix[:-1] + '_offset.weight' in state_dict): + state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( + prefix[:-1] + '_offset.weight') + if (prefix + 'conv_offset.bias' not in state_dict + and prefix[:-1] + '_offset.bias' in state_dict): + state_dict[prefix + + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + + '_offset.bias') + + if version is not None and version > 1: + print_log( + f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to ' + 'version 2.', + logger='root') + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) diff --git a/annotator/uniformer/mmcv/ops/deform_roi_pool.py b/annotator/uniformer/mmcv/ops/deform_roi_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..cc245ba91fee252226ba22e76bb94a35db9a629b --- /dev/null +++ b/annotator/uniformer/mmcv/ops/deform_roi_pool.py @@ -0,0 +1,204 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['deform_roi_pool_forward', 'deform_roi_pool_backward']) + + +class DeformRoIPoolFunction(Function): + + @staticmethod + def symbolic(g, input, rois, offset, output_size, spatial_scale, + sampling_ratio, gamma): + return g.op( + 'mmcv::MMCVDeformRoIPool', + input, + rois, + offset, + pooled_height_i=output_size[0], + pooled_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_f=sampling_ratio, + gamma_f=gamma) + + @staticmethod + def forward(ctx, + input, + rois, + offset, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + gamma=0.1): + if offset is None: + offset = input.new_zeros(0) + ctx.output_size = _pair(output_size) + ctx.spatial_scale = float(spatial_scale) + ctx.sampling_ratio = int(sampling_ratio) + ctx.gamma = float(gamma) + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' + + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + + ext_module.deform_roi_pool_forward( + input, + rois, + offset, + output, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + gamma=ctx.gamma) + + ctx.save_for_backward(input, rois, offset) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, rois, offset = ctx.saved_tensors + grad_input = grad_output.new_zeros(input.shape) + grad_offset = grad_output.new_zeros(offset.shape) + + ext_module.deform_roi_pool_backward( + grad_output, + input, + rois, + offset, + grad_input, + grad_offset, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + gamma=ctx.gamma) + if grad_offset.numel() == 0: + grad_offset = None + return grad_input, None, grad_offset, None, None, None, None + + +deform_roi_pool = DeformRoIPoolFunction.apply + + +class DeformRoIPool(nn.Module): + + def __init__(self, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + gamma=0.1): + super(DeformRoIPool, self).__init__() + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + self.sampling_ratio = int(sampling_ratio) + self.gamma = float(gamma) + + def forward(self, input, rois, offset=None): + return deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + + +class DeformRoIPoolPack(DeformRoIPool): + + def __init__(self, + output_size, + output_channels, + deform_fc_channels=1024, + spatial_scale=1.0, + sampling_ratio=0, + gamma=0.1): + super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale, + sampling_ratio, gamma) + + self.output_channels = output_channels + self.deform_fc_channels = deform_fc_channels + + self.offset_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 2)) + self.offset_fc[-1].weight.data.zero_() + self.offset_fc[-1].bias.data.zero_() + + def forward(self, input, rois): + assert input.size(1) == self.output_channels + x = deform_roi_pool(input, rois, None, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + rois_num = rois.size(0) + offset = self.offset_fc(x.view(rois_num, -1)) + offset = offset.view(rois_num, 2, self.output_size[0], + self.output_size[1]) + return deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + + +class ModulatedDeformRoIPoolPack(DeformRoIPool): + + def __init__(self, + output_size, + output_channels, + deform_fc_channels=1024, + spatial_scale=1.0, + sampling_ratio=0, + gamma=0.1): + super(ModulatedDeformRoIPoolPack, + self).__init__(output_size, spatial_scale, sampling_ratio, gamma) + + self.output_channels = output_channels + self.deform_fc_channels = deform_fc_channels + + self.offset_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 2)) + self.offset_fc[-1].weight.data.zero_() + self.offset_fc[-1].bias.data.zero_() + + self.mask_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 1), + nn.Sigmoid()) + self.mask_fc[2].weight.data.zero_() + self.mask_fc[2].bias.data.zero_() + + def forward(self, input, rois): + assert input.size(1) == self.output_channels + x = deform_roi_pool(input, rois, None, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + rois_num = rois.size(0) + offset = self.offset_fc(x.view(rois_num, -1)) + offset = offset.view(rois_num, 2, self.output_size[0], + self.output_size[1]) + mask = self.mask_fc(x.view(rois_num, -1)) + mask = mask.view(rois_num, 1, self.output_size[0], self.output_size[1]) + d = deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + return d * mask diff --git a/annotator/uniformer/mmcv/ops/deprecated_wrappers.py b/annotator/uniformer/mmcv/ops/deprecated_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e593df9ee57637038683d7a1efaa347b2b69e7 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/deprecated_wrappers.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This file is for backward compatibility. +# Module wrappers for empty tensor have been moved to mmcv.cnn.bricks. +import warnings + +from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d + + +class Conv2d_deprecated(Conv2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in' + ' the future. Please import them from "mmcv.cnn" instead') + + +class ConvTranspose2d_deprecated(ConvTranspose2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing ConvTranspose2d wrapper from "mmcv.ops" will be ' + 'deprecated in the future. Please import them from "mmcv.cnn" ' + 'instead') + + +class MaxPool2d_deprecated(MaxPool2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in' + ' the future. Please import them from "mmcv.cnn" instead') + + +class Linear_deprecated(Linear): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing Linear wrapper from "mmcv.ops" will be deprecated in' + ' the future. Please import them from "mmcv.cnn" instead') diff --git a/annotator/uniformer/mmcv/ops/focal_loss.py b/annotator/uniformer/mmcv/ops/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..763bc93bd2575c49ca8ccf20996bbd92d1e0d1a4 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/focal_loss.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward', + 'softmax_focal_loss_forward', 'softmax_focal_loss_backward' +]) + + +class SigmoidFocalLossFunction(Function): + + @staticmethod + def symbolic(g, input, target, gamma, alpha, weight, reduction): + return g.op( + 'mmcv::MMCVSigmoidFocalLoss', + input, + target, + gamma_f=gamma, + alpha_f=alpha, + weight_f=weight, + reduction_s=reduction) + + @staticmethod + def forward(ctx, + input, + target, + gamma=2.0, + alpha=0.25, + weight=None, + reduction='mean'): + + assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor)) + assert input.dim() == 2 + assert target.dim() == 1 + assert input.size(0) == target.size(0) + if weight is None: + weight = input.new_empty(0) + else: + assert weight.dim() == 1 + assert input.size(1) == weight.size(0) + ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2} + assert reduction in ctx.reduction_dict.keys() + + ctx.gamma = float(gamma) + ctx.alpha = float(alpha) + ctx.reduction = ctx.reduction_dict[reduction] + + output = input.new_zeros(input.size()) + + ext_module.sigmoid_focal_loss_forward( + input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha) + if ctx.reduction == ctx.reduction_dict['mean']: + output = output.sum() / input.size(0) + elif ctx.reduction == ctx.reduction_dict['sum']: + output = output.sum() + ctx.save_for_backward(input, target, weight) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, target, weight = ctx.saved_tensors + + grad_input = input.new_zeros(input.size()) + + ext_module.sigmoid_focal_loss_backward( + input, + target, + weight, + grad_input, + gamma=ctx.gamma, + alpha=ctx.alpha) + + grad_input *= grad_output + if ctx.reduction == ctx.reduction_dict['mean']: + grad_input /= input.size(0) + return grad_input, None, None, None, None, None + + +sigmoid_focal_loss = SigmoidFocalLossFunction.apply + + +class SigmoidFocalLoss(nn.Module): + + def __init__(self, gamma, alpha, weight=None, reduction='mean'): + super(SigmoidFocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.register_buffer('weight', weight) + self.reduction = reduction + + def forward(self, input, target): + return sigmoid_focal_loss(input, target, self.gamma, self.alpha, + self.weight, self.reduction) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(gamma={self.gamma}, ' + s += f'alpha={self.alpha}, ' + s += f'reduction={self.reduction})' + return s + + +class SoftmaxFocalLossFunction(Function): + + @staticmethod + def symbolic(g, input, target, gamma, alpha, weight, reduction): + return g.op( + 'mmcv::MMCVSoftmaxFocalLoss', + input, + target, + gamma_f=gamma, + alpha_f=alpha, + weight_f=weight, + reduction_s=reduction) + + @staticmethod + def forward(ctx, + input, + target, + gamma=2.0, + alpha=0.25, + weight=None, + reduction='mean'): + + assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor)) + assert input.dim() == 2 + assert target.dim() == 1 + assert input.size(0) == target.size(0) + if weight is None: + weight = input.new_empty(0) + else: + assert weight.dim() == 1 + assert input.size(1) == weight.size(0) + ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2} + assert reduction in ctx.reduction_dict.keys() + + ctx.gamma = float(gamma) + ctx.alpha = float(alpha) + ctx.reduction = ctx.reduction_dict[reduction] + + channel_stats, _ = torch.max(input, dim=1) + input_softmax = input - channel_stats.unsqueeze(1).expand_as(input) + input_softmax.exp_() + + channel_stats = input_softmax.sum(dim=1) + input_softmax /= channel_stats.unsqueeze(1).expand_as(input) + + output = input.new_zeros(input.size(0)) + ext_module.softmax_focal_loss_forward( + input_softmax, + target, + weight, + output, + gamma=ctx.gamma, + alpha=ctx.alpha) + + if ctx.reduction == ctx.reduction_dict['mean']: + output = output.sum() / input.size(0) + elif ctx.reduction == ctx.reduction_dict['sum']: + output = output.sum() + ctx.save_for_backward(input_softmax, target, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input_softmax, target, weight = ctx.saved_tensors + buff = input_softmax.new_zeros(input_softmax.size(0)) + grad_input = input_softmax.new_zeros(input_softmax.size()) + + ext_module.softmax_focal_loss_backward( + input_softmax, + target, + weight, + buff, + grad_input, + gamma=ctx.gamma, + alpha=ctx.alpha) + + grad_input *= grad_output + if ctx.reduction == ctx.reduction_dict['mean']: + grad_input /= input_softmax.size(0) + return grad_input, None, None, None, None, None + + +softmax_focal_loss = SoftmaxFocalLossFunction.apply + + +class SoftmaxFocalLoss(nn.Module): + + def __init__(self, gamma, alpha, weight=None, reduction='mean'): + super(SoftmaxFocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.register_buffer('weight', weight) + self.reduction = reduction + + def forward(self, input, target): + return softmax_focal_loss(input, target, self.gamma, self.alpha, + self.weight, self.reduction) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(gamma={self.gamma}, ' + s += f'alpha={self.alpha}, ' + s += f'reduction={self.reduction})' + return s diff --git a/annotator/uniformer/mmcv/ops/furthest_point_sample.py b/annotator/uniformer/mmcv/ops/furthest_point_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..374b7a878f1972c183941af28ba1df216ac1a60f --- /dev/null +++ b/annotator/uniformer/mmcv/ops/furthest_point_sample.py @@ -0,0 +1,83 @@ +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'furthest_point_sampling_forward', + 'furthest_point_sampling_with_dist_forward' +]) + + +class FurthestPointSampling(Function): + """Uses iterative furthest point sampling to select a set of features whose + corresponding points have the furthest distance.""" + + @staticmethod + def forward(ctx, points_xyz: torch.Tensor, + num_points: int) -> torch.Tensor: + """ + Args: + points_xyz (Tensor): (B, N, 3) where N > num_points. + num_points (int): Number of points in the sampled set. + + Returns: + Tensor: (B, num_points) indices of the sampled points. + """ + assert points_xyz.is_contiguous() + + B, N = points_xyz.size()[:2] + output = torch.cuda.IntTensor(B, num_points) + temp = torch.cuda.FloatTensor(B, N).fill_(1e10) + + ext_module.furthest_point_sampling_forward( + points_xyz, + temp, + output, + b=B, + n=N, + m=num_points, + ) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(output) + return output + + @staticmethod + def backward(xyz, a=None): + return None, None + + +class FurthestPointSamplingWithDist(Function): + """Uses iterative furthest point sampling to select a set of features whose + corresponding points have the furthest distance.""" + + @staticmethod + def forward(ctx, points_dist: torch.Tensor, + num_points: int) -> torch.Tensor: + """ + Args: + points_dist (Tensor): (B, N, N) Distance between each point pair. + num_points (int): Number of points in the sampled set. + + Returns: + Tensor: (B, num_points) indices of the sampled points. + """ + assert points_dist.is_contiguous() + + B, N, _ = points_dist.size() + output = points_dist.new_zeros([B, num_points], dtype=torch.int32) + temp = points_dist.new_zeros([B, N]).fill_(1e10) + + ext_module.furthest_point_sampling_with_dist_forward( + points_dist, temp, output, b=B, n=N, m=num_points) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(output) + return output + + @staticmethod + def backward(xyz, a=None): + return None, None + + +furthest_point_sample = FurthestPointSampling.apply +furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply diff --git a/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py b/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py new file mode 100644 index 0000000000000000000000000000000000000000..6d12508469c6c8fa1884debece44c58d158cb6fa --- /dev/null +++ b/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py @@ -0,0 +1,268 @@ +# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator +# Augmentation (ADA) +# ======================================================================= + +# 1. Definitions + +# "Licensor" means any person or entity that distributes its Work. + +# "Software" means the original work of authorship made available under +# this License. + +# "Work" means the Software and any additions to or derivative works of +# the Software that are made available under this License. + +# The terms "reproduce," "reproduction," "derivative works," and +# "distribution" have the meaning as provided under U.S. copyright law; +# provided, however, that for the purposes of this License, derivative +# works shall not include works that remain separable from, or merely +# link (or bind by name) to the interfaces of, the Work. + +# Works, including the Software, are "made available" under this License +# by including in or with the Work either (a) a copyright notice +# referencing the applicability of this License to the Work, or (b) a +# copy of this License. + +# 2. License Grants + +# 2.1 Copyright Grant. Subject to the terms and conditions of this +# License, each Licensor grants to you a perpetual, worldwide, +# non-exclusive, royalty-free, copyright license to reproduce, +# prepare derivative works of, publicly display, publicly perform, +# sublicense and distribute its Work and any resulting derivative +# works in any form. + +# 3. Limitations + +# 3.1 Redistribution. You may reproduce or distribute the Work only +# if (a) you do so under this License, (b) you include a complete +# copy of this License with your distribution, and (c) you retain +# without modification any copyright, patent, trademark, or +# attribution notices that are present in the Work. + +# 3.2 Derivative Works. You may specify that additional or different +# terms apply to the use, reproduction, and distribution of your +# derivative works of the Work ("Your Terms") only if (a) Your Terms +# provide that the use limitation in Section 3.3 applies to your +# derivative works, and (b) you identify the specific derivative +# works that are subject to Your Terms. Notwithstanding Your Terms, +# this License (including the redistribution requirements in Section +# 3.1) will continue to apply to the Work itself. + +# 3.3 Use Limitation. The Work and any derivative works thereof only +# may be used or intended for use non-commercially. Notwithstanding +# the foregoing, NVIDIA and its affiliates may use the Work and any +# derivative works commercially. As used herein, "non-commercially" +# means for research or evaluation purposes only. + +# 3.4 Patent Claims. If you bring or threaten to bring a patent claim +# against any Licensor (including any claim, cross-claim or +# counterclaim in a lawsuit) to enforce any patents that you allege +# are infringed by any Work, then your rights under this License from +# such Licensor (including the grant in Section 2.1) will terminate +# immediately. + +# 3.5 Trademarks. This License does not grant any rights to use any +# Licensor’s or its affiliates’ names, logos, or trademarks, except +# as necessary to reproduce the notices described in this License. + +# 3.6 Termination. If you violate any term of this License, then your +# rights under this License (including the grant in Section 2.1) will +# terminate immediately. + +# 4. Disclaimer of Warranty. + +# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +# THIS LICENSE. + +# 5. Limitation of Liability. + +# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGES. + +# ======================================================================= + +import torch +import torch.nn.functional as F +from torch import nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu']) + + +class FusedBiasLeakyReLUFunctionBackward(Function): + """Calculate second order deviation. + + This function is to compute the second order deviation for the fused leaky + relu operation. + """ + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = ext_module.fused_bias_leakyrelu( + grad_output, + empty, + out, + act=3, + grad=1, + alpha=negative_slope, + scale=scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + + # The second order deviation, in fact, contains two parts, while the + # the first part is zero. Thus, we direct consider the second part + # which is similar with the first order deviation in implementation. + gradgrad_out = ext_module.fused_bias_leakyrelu( + gradgrad_input, + gradgrad_bias.to(out.dtype), + out, + act=3, + grad=1, + alpha=ctx.negative_slope, + scale=ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedBiasLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + + out = ext_module.fused_bias_leakyrelu( + input, + bias, + empty, + act=3, + grad=0, + alpha=negative_slope, + scale=scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedBiasLeakyReLU(nn.Module): + """Fused bias leaky ReLU. + + This function is introduced in the StyleGAN2: + http://arxiv.org/abs/1912.04958 + + The bias term comes from the convolution operation. In addition, to keep + the variance of the feature map or gradients unchanged, they also adopt a + scale similarly with Kaiming initialization. However, since the + :math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the + final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501 + your own scale. + + TODO: Implement the CPU version. + + Args: + channel (int): The channel number of the feature map. + negative_slope (float, optional): Same as nn.LeakyRelu. + Defaults to 0.2. + scale (float, optional): A scalar to adjust the variance of the feature + map. Defaults to 2**0.5. + """ + + def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5): + super(FusedBiasLeakyReLU, self).__init__() + + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_bias_leakyrelu(input, self.bias, self.negative_slope, + self.scale) + + +def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5): + """Fused bias leaky ReLU function. + + This function is introduced in the StyleGAN2: + http://arxiv.org/abs/1912.04958 + + The bias term comes from the convolution operation. In addition, to keep + the variance of the feature map or gradients unchanged, they also adopt a + scale similarly with Kaiming initialization. However, since the + :math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the + final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501 + your own scale. + + Args: + input (torch.Tensor): Input feature map. + bias (nn.Parameter): The bias from convolution operation. + negative_slope (float, optional): Same as nn.LeakyRelu. + Defaults to 0.2. + scale (float, optional): A scalar to adjust the variance of the feature + map. Defaults to 2**0.5. + + Returns: + torch.Tensor: Feature map after non-linear activation. + """ + + if not input.is_cuda: + return bias_leakyrelu_ref(input, bias, negative_slope, scale) + + return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype), + negative_slope, scale) + + +def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5): + + if bias is not None: + assert bias.ndim == 1 + assert bias.shape[0] == x.shape[1] + x = x + bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)]) + + x = F.leaky_relu(x, negative_slope) + if scale != 1: + x = x * scale + + return x diff --git a/annotator/uniformer/mmcv/ops/gather_points.py b/annotator/uniformer/mmcv/ops/gather_points.py new file mode 100644 index 0000000000000000000000000000000000000000..f52f1677d8ea0facafc56a3672d37adb44677ff3 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/gather_points.py @@ -0,0 +1,57 @@ +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['gather_points_forward', 'gather_points_backward']) + + +class GatherPoints(Function): + """Gather points with given index.""" + + @staticmethod + def forward(ctx, features: torch.Tensor, + indices: torch.Tensor) -> torch.Tensor: + """ + Args: + features (Tensor): (B, C, N) features to gather. + indices (Tensor): (B, M) where M is the number of points. + + Returns: + Tensor: (B, C, M) where M is the number of points. + """ + assert features.is_contiguous() + assert indices.is_contiguous() + + B, npoint = indices.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, npoint) + + ext_module.gather_points_forward( + features, indices, output, b=B, c=C, n=N, npoints=npoint) + + ctx.for_backwards = (indices, C, N) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(indices) + return output + + @staticmethod + def backward(ctx, grad_out): + idx, C, N = ctx.for_backwards + B, npoint = idx.size() + + grad_features = torch.cuda.FloatTensor(B, C, N).zero_() + grad_out_data = grad_out.data.contiguous() + ext_module.gather_points_backward( + grad_out_data, + idx, + grad_features.data, + b=B, + c=C, + n=N, + npoints=npoint) + return grad_features, None + + +gather_points = GatherPoints.apply diff --git a/annotator/uniformer/mmcv/ops/group_points.py b/annotator/uniformer/mmcv/ops/group_points.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3ec9d758ebe4e1c2205882af4be154008253a5 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/group_points.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from torch import nn as nn +from torch.autograd import Function + +from ..utils import ext_loader +from .ball_query import ball_query +from .knn import knn + +ext_module = ext_loader.load_ext( + '_ext', ['group_points_forward', 'group_points_backward']) + + +class QueryAndGroup(nn.Module): + """Groups points with a ball query of radius. + + Args: + max_radius (float): The maximum radius of the balls. + If None is given, we will use kNN sampling instead of ball query. + sample_num (int): Maximum number of features to gather in the ball. + min_radius (float, optional): The minimum radius of the balls. + Default: 0. + use_xyz (bool, optional): Whether to use xyz. + Default: True. + return_grouped_xyz (bool, optional): Whether to return grouped xyz. + Default: False. + normalize_xyz (bool, optional): Whether to normalize xyz. + Default: False. + uniform_sample (bool, optional): Whether to sample uniformly. + Default: False + return_unique_cnt (bool, optional): Whether to return the count of + unique samples. Default: False. + return_grouped_idx (bool, optional): Whether to return grouped idx. + Default: False. + """ + + def __init__(self, + max_radius, + sample_num, + min_radius=0, + use_xyz=True, + return_grouped_xyz=False, + normalize_xyz=False, + uniform_sample=False, + return_unique_cnt=False, + return_grouped_idx=False): + super().__init__() + self.max_radius = max_radius + self.min_radius = min_radius + self.sample_num = sample_num + self.use_xyz = use_xyz + self.return_grouped_xyz = return_grouped_xyz + self.normalize_xyz = normalize_xyz + self.uniform_sample = uniform_sample + self.return_unique_cnt = return_unique_cnt + self.return_grouped_idx = return_grouped_idx + if self.return_unique_cnt: + assert self.uniform_sample, \ + 'uniform_sample should be True when ' \ + 'returning the count of unique samples' + if self.max_radius is None: + assert not self.normalize_xyz, \ + 'can not normalize grouped xyz when max_radius is None' + + def forward(self, points_xyz, center_xyz, features=None): + """ + Args: + points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. + center_xyz (Tensor): (B, npoint, 3) coordinates of the centriods. + features (Tensor): (B, C, N) Descriptors of the features. + + Returns: + Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. + """ + # if self.max_radius is None, we will perform kNN instead of ball query + # idx is of shape [B, npoint, sample_num] + if self.max_radius is None: + idx = knn(self.sample_num, points_xyz, center_xyz, False) + idx = idx.transpose(1, 2).contiguous() + else: + idx = ball_query(self.min_radius, self.max_radius, self.sample_num, + points_xyz, center_xyz) + + if self.uniform_sample: + unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) + for i_batch in range(idx.shape[0]): + for i_region in range(idx.shape[1]): + unique_ind = torch.unique(idx[i_batch, i_region, :]) + num_unique = unique_ind.shape[0] + unique_cnt[i_batch, i_region] = num_unique + sample_ind = torch.randint( + 0, + num_unique, (self.sample_num - num_unique, ), + dtype=torch.long) + all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) + idx[i_batch, i_region, :] = all_ind + + xyz_trans = points_xyz.transpose(1, 2).contiguous() + # (B, 3, npoint, sample_num) + grouped_xyz = grouping_operation(xyz_trans, idx) + grouped_xyz_diff = grouped_xyz - \ + center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets + if self.normalize_xyz: + grouped_xyz_diff /= self.max_radius + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + # (B, C + 3, npoint, sample_num) + new_features = torch.cat([grouped_xyz_diff, grouped_features], + dim=1) + else: + new_features = grouped_features + else: + assert (self.use_xyz + ), 'Cannot have not features and not use xyz as a feature!' + new_features = grouped_xyz_diff + + ret = [new_features] + if self.return_grouped_xyz: + ret.append(grouped_xyz) + if self.return_unique_cnt: + ret.append(unique_cnt) + if self.return_grouped_idx: + ret.append(idx) + if len(ret) == 1: + return ret[0] + else: + return tuple(ret) + + +class GroupAll(nn.Module): + """Group xyz with feature. + + Args: + use_xyz (bool): Whether to use xyz. + """ + + def __init__(self, use_xyz: bool = True): + super().__init__() + self.use_xyz = use_xyz + + def forward(self, + xyz: torch.Tensor, + new_xyz: torch.Tensor, + features: torch.Tensor = None): + """ + Args: + xyz (Tensor): (B, N, 3) xyz coordinates of the features. + new_xyz (Tensor): new xyz coordinates of the features. + features (Tensor): (B, C, N) features to group. + + Returns: + Tensor: (B, C + 3, 1, N) Grouped feature. + """ + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + # (B, 3 + C, 1, N) + new_features = torch.cat([grouped_xyz, grouped_features], + dim=1) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + return new_features + + +class GroupingOperation(Function): + """Group feature with given index.""" + + @staticmethod + def forward(ctx, features: torch.Tensor, + indices: torch.Tensor) -> torch.Tensor: + """ + Args: + features (Tensor): (B, C, N) tensor of features to group. + indices (Tensor): (B, npoint, nsample) the indices of + features to group with. + + Returns: + Tensor: (B, C, npoint, nsample) Grouped features. + """ + features = features.contiguous() + indices = indices.contiguous() + + B, nfeatures, nsample = indices.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) + + ext_module.group_points_forward(B, C, N, nfeatures, nsample, features, + indices, output) + + ctx.for_backwards = (indices, N) + return output + + @staticmethod + def backward(ctx, + grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients + of the output from forward. + + Returns: + Tensor: (B, C, N) gradient of the features. + """ + idx, N = ctx.for_backwards + + B, C, npoint, nsample = grad_out.size() + grad_features = torch.cuda.FloatTensor(B, C, N).zero_() + + grad_out_data = grad_out.data.contiguous() + ext_module.group_points_backward(B, C, N, npoint, nsample, + grad_out_data, idx, + grad_features.data) + return grad_features, None + + +grouping_operation = GroupingOperation.apply diff --git a/annotator/uniformer/mmcv/ops/info.py b/annotator/uniformer/mmcv/ops/info.py new file mode 100644 index 0000000000000000000000000000000000000000..29f2e5598ae2bb5866ccd15a7d3b4de33c0cd14d --- /dev/null +++ b/annotator/uniformer/mmcv/ops/info.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import glob +import os + +import torch + +if torch.__version__ == 'parrots': + import parrots + + def get_compiler_version(): + return 'GCC ' + parrots.version.compiler + + def get_compiling_cuda_version(): + return parrots.version.cuda +else: + from ..utils import ext_loader + ext_module = ext_loader.load_ext( + '_ext', ['get_compiler_version', 'get_compiling_cuda_version']) + + def get_compiler_version(): + return ext_module.get_compiler_version() + + def get_compiling_cuda_version(): + return ext_module.get_compiling_cuda_version() + + +def get_onnxruntime_op_path(): + wildcard = os.path.join( + os.path.abspath(os.path.dirname(os.path.dirname(__file__))), + '_ext_ort.*.so') + + paths = glob.glob(wildcard) + if len(paths) > 0: + return paths[0] + else: + return '' diff --git a/annotator/uniformer/mmcv/ops/iou3d.py b/annotator/uniformer/mmcv/ops/iou3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc71979190323f44c09f8b7e1761cf49cd2d76b --- /dev/null +++ b/annotator/uniformer/mmcv/ops/iou3d.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward', + 'iou3d_nms_normal_forward' +]) + + +def boxes_iou_bev(boxes_a, boxes_b): + """Calculate boxes IoU in the Bird's Eye View. + + Args: + boxes_a (torch.Tensor): Input boxes a with shape (M, 5). + boxes_b (torch.Tensor): Input boxes b with shape (N, 5). + + Returns: + ans_iou (torch.Tensor): IoU result with shape (M, N). + """ + ans_iou = boxes_a.new_zeros( + torch.Size((boxes_a.shape[0], boxes_b.shape[0]))) + + ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(), + boxes_b.contiguous(), ans_iou) + + return ans_iou + + +def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): + """NMS function GPU implementation (for BEV boxes). The overlap of two + boxes for IoU calculation is defined as the exact overlapping area of the + two boxes. In this function, one can also set ``pre_max_size`` and + ``post_max_size``. + + Args: + boxes (torch.Tensor): Input boxes with the shape of [N, 5] + ([x1, y1, x2, y2, ry]). + scores (torch.Tensor): Scores of boxes with the shape of [N]. + thresh (float): Overlap threshold of NMS. + pre_max_size (int, optional): Max size of boxes before NMS. + Default: None. + post_max_size (int, optional): Max size of boxes after NMS. + Default: None. + + Returns: + torch.Tensor: Indexes after NMS. + """ + assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]' + order = scores.sort(0, descending=True)[1] + + if pre_max_size is not None: + order = order[:pre_max_size] + boxes = boxes[order].contiguous() + + keep = torch.zeros(boxes.size(0), dtype=torch.long) + num_out = ext_module.iou3d_nms_forward(boxes, keep, thresh) + keep = order[keep[:num_out].cuda(boxes.device)].contiguous() + if post_max_size is not None: + keep = keep[:post_max_size] + return keep + + +def nms_normal_bev(boxes, scores, thresh): + """Normal NMS function GPU implementation (for BEV boxes). The overlap of + two boxes for IoU calculation is defined as the exact overlapping area of + the two boxes WITH their yaw angle set to 0. + + Args: + boxes (torch.Tensor): Input boxes with shape (N, 5). + scores (torch.Tensor): Scores of predicted boxes with shape (N). + thresh (float): Overlap threshold of NMS. + + Returns: + torch.Tensor: Remaining indices with scores in descending order. + """ + assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' + order = scores.sort(0, descending=True)[1] + + boxes = boxes[order].contiguous() + + keep = torch.zeros(boxes.size(0), dtype=torch.long) + num_out = ext_module.iou3d_nms_normal_forward(boxes, keep, thresh) + return order[keep[:num_out].cuda(boxes.device)].contiguous() diff --git a/annotator/uniformer/mmcv/ops/knn.py b/annotator/uniformer/mmcv/ops/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..f335785036669fc19239825b0aae6dde3f73bf92 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/knn.py @@ -0,0 +1,77 @@ +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['knn_forward']) + + +class KNN(Function): + r"""KNN (CUDA) based on heap data structure. + Modified from `PAConv `_. + + Find k-nearest points. + """ + + @staticmethod + def forward(ctx, + k: int, + xyz: torch.Tensor, + center_xyz: torch.Tensor = None, + transposed: bool = False) -> torch.Tensor: + """ + Args: + k (int): number of nearest neighbors. + xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N). + xyz coordinates of the features. + center_xyz (Tensor, optional): (B, npoint, 3) if transposed == + False, else (B, 3, npoint). centers of the knn query. + Default: None. + transposed (bool, optional): whether the input tensors are + transposed. Should not explicitly use this keyword when + calling knn (=KNN.apply), just add the fourth param. + Default: False. + + Returns: + Tensor: (B, k, npoint) tensor with the indices of + the features that form k-nearest neighbours. + """ + assert (k > 0) & (k < 100), 'k should be in range(0, 100)' + + if center_xyz is None: + center_xyz = xyz + + if transposed: + xyz = xyz.transpose(2, 1).contiguous() + center_xyz = center_xyz.transpose(2, 1).contiguous() + + assert xyz.is_contiguous() # [B, N, 3] + assert center_xyz.is_contiguous() # [B, npoint, 3] + + center_xyz_device = center_xyz.get_device() + assert center_xyz_device == xyz.get_device(), \ + 'center_xyz and xyz should be put on the same device' + if torch.cuda.current_device() != center_xyz_device: + torch.cuda.set_device(center_xyz_device) + + B, npoint, _ = center_xyz.shape + N = xyz.shape[1] + + idx = center_xyz.new_zeros((B, npoint, k)).int() + dist2 = center_xyz.new_zeros((B, npoint, k)).float() + + ext_module.knn_forward( + xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k) + # idx shape to [B, k, npoint] + idx = idx.transpose(2, 1).contiguous() + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None, None + + +knn = KNN.apply diff --git a/annotator/uniformer/mmcv/ops/masked_conv.py b/annotator/uniformer/mmcv/ops/masked_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..cd514cc204c1d571ea5dc7e74b038c0f477a008b --- /dev/null +++ b/annotator/uniformer/mmcv/ops/masked_conv.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['masked_im2col_forward', 'masked_col2im_forward']) + + +class MaskedConv2dFunction(Function): + + @staticmethod + def symbolic(g, features, mask, weight, bias, padding, stride): + return g.op( + 'mmcv::MMCVMaskedConv2d', + features, + mask, + weight, + bias, + padding_i=padding, + stride_i=stride) + + @staticmethod + def forward(ctx, features, mask, weight, bias, padding=0, stride=1): + assert mask.dim() == 3 and mask.size(0) == 1 + assert features.dim() == 4 and features.size(0) == 1 + assert features.size()[2:] == mask.size()[1:] + pad_h, pad_w = _pair(padding) + stride_h, stride_w = _pair(stride) + if stride_h != 1 or stride_w != 1: + raise ValueError( + 'Stride could not only be 1 in masked_conv2d currently.') + out_channel, in_channel, kernel_h, kernel_w = weight.size() + + batch_size = features.size(0) + out_h = int( + math.floor((features.size(2) + 2 * pad_h - + (kernel_h - 1) - 1) / stride_h + 1)) + out_w = int( + math.floor((features.size(3) + 2 * pad_w - + (kernel_h - 1) - 1) / stride_w + 1)) + mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False) + output = features.new_zeros(batch_size, out_channel, out_h, out_w) + if mask_inds.numel() > 0: + mask_h_idx = mask_inds[:, 0].contiguous() + mask_w_idx = mask_inds[:, 1].contiguous() + data_col = features.new_zeros(in_channel * kernel_h * kernel_w, + mask_inds.size(0)) + ext_module.masked_im2col_forward( + features, + mask_h_idx, + mask_w_idx, + data_col, + kernel_h=kernel_h, + kernel_w=kernel_w, + pad_h=pad_h, + pad_w=pad_w) + + masked_output = torch.addmm(1, bias[:, None], 1, + weight.view(out_channel, -1), data_col) + ext_module.masked_col2im_forward( + masked_output, + mask_h_idx, + mask_w_idx, + output, + height=out_h, + width=out_w, + channels=out_channel) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + return (None, ) * 5 + + +masked_conv2d = MaskedConv2dFunction.apply + + +class MaskedConv2d(nn.Conv2d): + """A MaskedConv2d which inherits the official Conv2d. + + The masked forward doesn't implement the backward function and only + supports the stride parameter to be 1 currently. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super(MaskedConv2d, + self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, input, mask=None): + if mask is None: # fallback to the normal Conv2d + return super(MaskedConv2d, self).forward(input) + else: + return masked_conv2d(input, mask, self.weight, self.bias, + self.padding) diff --git a/annotator/uniformer/mmcv/ops/merge_cells.py b/annotator/uniformer/mmcv/ops/merge_cells.py new file mode 100644 index 0000000000000000000000000000000000000000..48ca8cc0a8aca8432835bd760c0403a3c35b34cf --- /dev/null +++ b/annotator/uniformer/mmcv/ops/merge_cells.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..cnn import ConvModule + + +class BaseMergeCell(nn.Module): + """The basic class for cells used in NAS-FPN and NAS-FCOS. + + BaseMergeCell takes 2 inputs. After applying convolution + on them, they are resized to the target size. Then, + they go through binary_op, which depends on the type of cell. + If with_out_conv is True, the result of output will go through + another convolution layer. + + Args: + in_channels (int): number of input channels in out_conv layer. + out_channels (int): number of output channels in out_conv layer. + with_out_conv (bool): Whether to use out_conv layer + out_conv_cfg (dict): Config dict for convolution layer, which should + contain "groups", "kernel_size", "padding", "bias" to build + out_conv layer. + out_norm_cfg (dict): Config dict for normalization layer in out_conv. + out_conv_order (tuple): The order of conv/norm/activation layers in + out_conv. + with_input1_conv (bool): Whether to use convolution on input1. + with_input2_conv (bool): Whether to use convolution on input2. + input_conv_cfg (dict): Config dict for building input1_conv layer and + input2_conv layer, which is expected to contain the type of + convolution. + Default: None, which means using conv2d. + input_norm_cfg (dict): Config dict for normalization layer in + input1_conv and input2_conv layer. Default: None. + upsample_mode (str): Interpolation method used to resize the output + of input1_conv and input2_conv to target size. Currently, we + support ['nearest', 'bilinear']. Default: 'nearest'. + """ + + def __init__(self, + fused_channels=256, + out_channels=256, + with_out_conv=True, + out_conv_cfg=dict( + groups=1, kernel_size=3, padding=1, bias=True), + out_norm_cfg=None, + out_conv_order=('act', 'conv', 'norm'), + with_input1_conv=False, + with_input2_conv=False, + input_conv_cfg=None, + input_norm_cfg=None, + upsample_mode='nearest'): + super(BaseMergeCell, self).__init__() + assert upsample_mode in ['nearest', 'bilinear'] + self.with_out_conv = with_out_conv + self.with_input1_conv = with_input1_conv + self.with_input2_conv = with_input2_conv + self.upsample_mode = upsample_mode + + if self.with_out_conv: + self.out_conv = ConvModule( + fused_channels, + out_channels, + **out_conv_cfg, + norm_cfg=out_norm_cfg, + order=out_conv_order) + + self.input1_conv = self._build_input_conv( + out_channels, input_conv_cfg, + input_norm_cfg) if with_input1_conv else nn.Sequential() + self.input2_conv = self._build_input_conv( + out_channels, input_conv_cfg, + input_norm_cfg) if with_input2_conv else nn.Sequential() + + def _build_input_conv(self, channel, conv_cfg, norm_cfg): + return ConvModule( + channel, + channel, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True) + + @abstractmethod + def _binary_op(self, x1, x2): + pass + + def _resize(self, x, size): + if x.shape[-2:] == size: + return x + elif x.shape[-2:] < size: + return F.interpolate(x, size=size, mode=self.upsample_mode) + else: + assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0 + kernel_size = x.shape[-1] // size[-1] + x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size) + return x + + def forward(self, x1, x2, out_size=None): + assert x1.shape[:2] == x2.shape[:2] + assert out_size is None or len(out_size) == 2 + if out_size is None: # resize to larger one + out_size = max(x1.size()[2:], x2.size()[2:]) + + x1 = self.input1_conv(x1) + x2 = self.input2_conv(x2) + + x1 = self._resize(x1, out_size) + x2 = self._resize(x2, out_size) + + x = self._binary_op(x1, x2) + if self.with_out_conv: + x = self.out_conv(x) + return x + + +class SumCell(BaseMergeCell): + + def __init__(self, in_channels, out_channels, **kwargs): + super(SumCell, self).__init__(in_channels, out_channels, **kwargs) + + def _binary_op(self, x1, x2): + return x1 + x2 + + +class ConcatCell(BaseMergeCell): + + def __init__(self, in_channels, out_channels, **kwargs): + super(ConcatCell, self).__init__(in_channels * 2, out_channels, + **kwargs) + + def _binary_op(self, x1, x2): + ret = torch.cat([x1, x2], dim=1) + return ret + + +class GlobalPoolingCell(BaseMergeCell): + + def __init__(self, in_channels=None, out_channels=None, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + + def _binary_op(self, x1, x2): + x2_att = self.global_pool(x2).sigmoid() + return x2 + x2_att * x1 diff --git a/annotator/uniformer/mmcv/ops/modulated_deform_conv.py b/annotator/uniformer/mmcv/ops/modulated_deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..75559579cf053abcc99538606cbb88c723faf783 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/modulated_deform_conv.py @@ -0,0 +1,282 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair, _single + +from annotator.uniformer.mmcv.utils import deprecated_api_warning +from ..cnn import CONV_LAYERS +from ..utils import ext_loader, print_log + +ext_module = ext_loader.load_ext( + '_ext', + ['modulated_deform_conv_forward', 'modulated_deform_conv_backward']) + + +class ModulatedDeformConv2dFunction(Function): + + @staticmethod + def symbolic(g, input, offset, mask, weight, bias, stride, padding, + dilation, groups, deform_groups): + input_tensors = [input, offset, mask, weight] + if bias is not None: + input_tensors.append(bias) + return g.op( + 'mmcv::MMCVModulatedDeformConv2d', + *input_tensors, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups) + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1): + if input is not None and input.dim() != 4: + raise ValueError( + f'Expected 4D tensor as input, got {input.dim()}D tensor \ + instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deform_groups = deform_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(0) # fake tensor + # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; + # amp won't cast the type of model (float32), but "offset" is cast + # to float16 by nn.Conv2d automatically, leading to the type + # mismatch with input (when it is float32) or weight. + # The flag for whether to use fp16 or amp is the type of "offset", + # we cast weight and input to temporarily support fp16 and amp + # whatever the pytorch version is. + input = input.type_as(offset) + weight = weight.type_as(input) + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty( + ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + ext_module.modulated_deform_conv_forward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + output, + ctx._bufs[1], + kernel_h=weight.size(2), + kernel_w=weight.size(3), + stride_h=ctx.stride[0], + stride_w=ctx.stride[1], + pad_h=ctx.padding[0], + pad_w=ctx.padding[1], + dilation_h=ctx.dilation[0], + dilation_w=ctx.dilation[1], + group=ctx.groups, + deformable_group=ctx.deform_groups, + with_bias=ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + grad_output = grad_output.contiguous() + ext_module.modulated_deform_conv_backward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + kernel_h=weight.size(2), + kernel_w=weight.size(3), + stride_h=ctx.stride[0], + stride_w=ctx.stride[1], + pad_h=ctx.padding[0], + pad_w=ctx.padding[1], + dilation_h=ctx.dilation[0], + dilation_w=ctx.dilation[1], + group=ctx.groups, + deformable_group=ctx.deform_groups, + with_bias=ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, None, None, None) + + @staticmethod + def _output_size(ctx, input, weight): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = ctx.padding[d] + kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = ctx.stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + 'convolution input is too small (output would be ' + + 'x'.join(map(str, output_size)) + ')') + return output_size + + +modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply + + +class ModulatedDeformConv2d(nn.Module): + + @deprecated_api_warning({'deformable_groups': 'deform_groups'}, + cls_name='ModulatedDeformConv2d') + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1, + bias=True): + super(ModulatedDeformConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deform_groups = deform_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, + *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, + self.dilation, self.groups, + self.deform_groups) + + +@CONV_LAYERS.register_module('DCNv2') +class ModulatedDeformConv2dPack(ModulatedDeformConv2d): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv + layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int): Same as nn.Conv2d, while tuple is not supported. + padding (int): Same as nn.Conv2d, while tuple is not supported. + dilation (int): Same as nn.Conv2d, while tuple is not supported. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConv2dPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, + self.dilation, self.groups, + self.deform_groups) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, ModulatedDeformConvPack + # loads previous benchmark models. + if (prefix + 'conv_offset.weight' not in state_dict + and prefix[:-1] + '_offset.weight' in state_dict): + state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( + prefix[:-1] + '_offset.weight') + if (prefix + 'conv_offset.bias' not in state_dict + and prefix[:-1] + '_offset.bias' in state_dict): + state_dict[prefix + + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + + '_offset.bias') + + if version is not None and version > 1: + print_log( + f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to ' + 'version 2.', + logger='root') + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) diff --git a/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py b/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c52dda18b41705705b47dd0e995b124048c16fba --- /dev/null +++ b/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py @@ -0,0 +1,358 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd.function import Function, once_differentiable + +from annotator.uniformer.mmcv import deprecated_api_warning +from annotator.uniformer.mmcv.cnn import constant_init, xavier_init +from annotator.uniformer.mmcv.cnn.bricks.registry import ATTENTION +from annotator.uniformer.mmcv.runner import BaseModule +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) + + +class MultiScaleDeformableAttnFunction(Function): + + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights, im2col_step): + """GPU version of multi-scale deformable attention. + + Args: + value (Tensor): The value has shape + (bs, num_keys, mum_heads, embed_dims//num_heads) + value_spatial_shapes (Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). + attention_weights (Tensor): The weight of sampling points used + when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + im2col_step (Tensor): The step used in image to column. + + Returns: + Tensor: has shape (bs, num_queries, embed_dims) + """ + + ctx.im2col_step = im2col_step + output = ext_module.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step=ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, + value_level_start_index, sampling_locations, + attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + """GPU version of backward function. + + Args: + grad_output (Tensor): Gradient + of output tensor of forward. + + Returns: + Tuple[Tensor]: Gradient + of input tensors in forward. + """ + value, value_spatial_shapes, value_level_start_index,\ + sampling_locations, attention_weights = ctx.saved_tensors + grad_value = torch.zeros_like(value) + grad_sampling_loc = torch.zeros_like(sampling_locations) + grad_attn_weight = torch.zeros_like(attention_weights) + + ext_module.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output.contiguous(), + grad_value, + grad_sampling_loc, + grad_attn_weight, + im2col_step=ctx.im2col_step) + + return grad_value, None, None, \ + grad_sampling_loc, grad_attn_weight, None + + +def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, + sampling_locations, attention_weights): + """CPU version of multi-scale deformable attention. + + Args: + value (Tensor): The value has shape + (bs, num_keys, mum_heads, embed_dims//num_heads) + value_spatial_shapes (Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). + attention_weights (Tensor): The weight of sampling points used + when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + + Returns: + Tensor: has shape (bs, num_queries, embed_dims) + """ + + bs, _, num_heads, embed_dims = value.shape + _, num_queries, num_heads, num_levels, num_points, _ =\ + sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], + dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (H_, W_) in enumerate(value_spatial_shapes): + # bs, H_*W_, num_heads, embed_dims -> + # bs, H_*W_, num_heads*embed_dims -> + # bs, num_heads*embed_dims, H_*W_ -> + # bs*num_heads, embed_dims, H_, W_ + value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape( + bs * num_heads, embed_dims, H_, W_) + # bs, num_queries, num_heads, num_points, 2 -> + # bs, num_heads, num_queries, num_points, 2 -> + # bs*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, + level].transpose(1, 2).flatten(0, 1) + # bs*num_heads, embed_dims, num_queries, num_points + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (bs, num_queries, num_heads, num_levels, num_points) -> + # (bs, num_heads, num_queries, num_levels, num_points) -> + # (bs, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * + attention_weights).sum(-1).view(bs, num_heads * embed_dims, + num_queries) + return output.transpose(1, 2).contiguous() + + +@ATTENTION.register_module() +class MultiScaleDeformableAttention(BaseModule): + """An attention module used in Deformable-Detr. + + `Deformable DETR: Deformable Transformers for End-to-End Object Detection. + `_. + + Args: + embed_dims (int): The embedding dimension of Attention. + Default: 256. + num_heads (int): Parallel attention heads. Default: 64. + num_levels (int): The number of feature map used in + Attention. Default: 4. + num_points (int): The number of sampling points for + each query in each head. Default: 4. + im2col_step (int): The step used in image_to_column. + Default: 64. + dropout (float): A Dropout layer on `inp_identity`. + Default: 0.1. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default to False. + norm_cfg (dict): Config dict for normalization layer. + Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims=256, + num_heads=8, + num_levels=4, + num_points=4, + im2col_step=64, + dropout=0.1, + batch_first=False, + norm_cfg=None, + init_cfg=None): + super().__init__(init_cfg) + if embed_dims % num_heads != 0: + raise ValueError(f'embed_dims must be divisible by num_heads, ' + f'but got {embed_dims} and {num_heads}') + dim_per_head = embed_dims // num_heads + self.norm_cfg = norm_cfg + self.dropout = nn.Dropout(dropout) + self.batch_first = batch_first + + # you'd better set dim_per_head to a power of 2 + # which is more efficient in the CUDA implementation + def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + if not _is_power_of_2(dim_per_head): + warnings.warn( + "You'd better set embed_dims in " + 'MultiScaleDeformAttention to make ' + 'the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.sampling_offsets = nn.Linear( + embed_dims, num_heads * num_levels * num_points * 2) + self.attention_weights = nn.Linear(embed_dims, + num_heads * num_levels * num_points) + self.value_proj = nn.Linear(embed_dims, embed_dims) + self.output_proj = nn.Linear(embed_dims, embed_dims) + self.init_weights() + + def init_weights(self): + """Default initialization for Parameters of Module.""" + constant_init(self.sampling_offsets, 0.) + thetas = torch.arange( + self.num_heads, + dtype=torch.float32) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / + grid_init.abs().max(-1, keepdim=True)[0]).view( + self.num_heads, 1, 1, + 2).repeat(1, self.num_levels, self.num_points, 1) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + + self.sampling_offsets.bias.data = grid_init.view(-1) + constant_init(self.attention_weights, val=0., bias=0.) + xavier_init(self.value_proj, distribution='uniform', bias=0.) + xavier_init(self.output_proj, distribution='uniform', bias=0.) + self._is_init = True + + @deprecated_api_warning({'residual': 'identity'}, + cls_name='MultiScaleDeformableAttention') + def forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + **kwargs): + """Forward Function of MultiScaleDeformAttention. + + Args: + query (Tensor): Query of Transformer with shape + (num_query, bs, embed_dims). + key (Tensor): The key tensor with shape + `(num_key, bs, embed_dims)`. + value (Tensor): The value tensor with shape + `(num_key, bs, embed_dims)`. + identity (Tensor): The tensor used for addition, with the + same shape as `query`. Default None. If None, + `query` will be used. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. Default + None. + reference_points (Tensor): The normalized reference + points with shape (bs, num_query, num_levels, 2), + all elements is range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area. + or (N, Length_{query}, num_levels, 4), add + additional two dimensions is (w, h) to + form reference boxes. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_key]. + spatial_shapes (Tensor): Spatial shape of features in + different levels. With shape (num_levels, 2), + last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + + if value is None: + value = query + + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_levels, + self.num_points) + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets \ + / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.num_points \ + * reference_points[:, :, None, :, None, 2:] \ + * 0.5 + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + if torch.cuda.is_available() and value.is_cuda: + output = MultiScaleDeformableAttnFunction.apply( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights, self.im2col_step) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights) + + output = self.output_proj(output) + + if not self.batch_first: + # (num_query, bs ,embed_dims) + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity diff --git a/annotator/uniformer/mmcv/ops/nms.py b/annotator/uniformer/mmcv/ops/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9634281f486ab284091786886854c451368052 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/nms.py @@ -0,0 +1,417 @@ +import os + +import numpy as np +import torch + +from annotator.uniformer.mmcv.utils import deprecated_api_warning +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated']) + + +# This function is modified from: https://github.com/pytorch/vision/ +class NMSop(torch.autograd.Function): + + @staticmethod + def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold, + max_num): + is_filtering_by_score = score_threshold > 0 + if is_filtering_by_score: + valid_mask = scores > score_threshold + bboxes, scores = bboxes[valid_mask], scores[valid_mask] + valid_inds = torch.nonzero( + valid_mask, as_tuple=False).squeeze(dim=1) + + inds = ext_module.nms( + bboxes, scores, iou_threshold=float(iou_threshold), offset=offset) + + if max_num > 0: + inds = inds[:max_num] + if is_filtering_by_score: + inds = valid_inds[inds] + return inds + + @staticmethod + def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold, + max_num): + from ..onnx import is_custom_op_loaded + has_custom_op = is_custom_op_loaded() + # TensorRT nms plugin is aligned with original nms in ONNXRuntime + is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT' + if has_custom_op and (not is_trt_backend): + return g.op( + 'mmcv::NonMaxSuppression', + bboxes, + scores, + iou_threshold_f=float(iou_threshold), + offset_i=int(offset)) + else: + from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze + from ..onnx.onnx_utils.symbolic_helper import _size_helper + + boxes = unsqueeze(g, bboxes, 0) + scores = unsqueeze(g, unsqueeze(g, scores, 0), 0) + + if max_num > 0: + max_num = g.op( + 'Constant', + value_t=torch.tensor(max_num, dtype=torch.long)) + else: + dim = g.op('Constant', value_t=torch.tensor(0)) + max_num = _size_helper(g, bboxes, dim) + max_output_per_class = max_num + iou_threshold = g.op( + 'Constant', + value_t=torch.tensor([iou_threshold], dtype=torch.float)) + score_threshold = g.op( + 'Constant', + value_t=torch.tensor([score_threshold], dtype=torch.float)) + nms_out = g.op('NonMaxSuppression', boxes, scores, + max_output_per_class, iou_threshold, + score_threshold) + return squeeze( + g, + select( + g, nms_out, 1, + g.op( + 'Constant', + value_t=torch.tensor([2], dtype=torch.long))), 1) + + +class SoftNMSop(torch.autograd.Function): + + @staticmethod + def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method, + offset): + dets = boxes.new_empty((boxes.size(0), 5), device='cpu') + inds = ext_module.softnms( + boxes.cpu(), + scores.cpu(), + dets.cpu(), + iou_threshold=float(iou_threshold), + sigma=float(sigma), + min_score=float(min_score), + method=int(method), + offset=int(offset)) + return dets, inds + + @staticmethod + def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method, + offset): + from packaging import version + assert version.parse(torch.__version__) >= version.parse('1.7.0') + nms_out = g.op( + 'mmcv::SoftNonMaxSuppression', + boxes, + scores, + iou_threshold_f=float(iou_threshold), + sigma_f=float(sigma), + min_score_f=float(min_score), + method_i=int(method), + offset_i=int(offset), + outputs=2) + return nms_out + + +@deprecated_api_warning({'iou_thr': 'iou_threshold'}) +def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1): + """Dispatch to either CPU or GPU NMS implementations. + + The input can be either torch tensor or numpy array. GPU NMS will be used + if the input is gpu tensor, otherwise CPU NMS + will be used. The returned type will always be the same as inputs. + + Arguments: + boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4). + scores (torch.Tensor or np.ndarray): scores in shape (N, ). + iou_threshold (float): IoU threshold for NMS. + offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset). + score_threshold (float): score threshold for NMS. + max_num (int): maximum number of boxes after NMS. + + Returns: + tuple: kept dets(boxes and scores) and indice, which is always the \ + same data type as the input. + + Example: + >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9], + >>> [49.3, 32.9, 51.0, 35.3], + >>> [49.2, 31.8, 51.0, 35.4], + >>> [35.1, 11.5, 39.1, 15.7], + >>> [35.6, 11.8, 39.3, 14.2], + >>> [35.3, 11.5, 39.9, 14.5], + >>> [35.2, 11.7, 39.7, 15.7]], dtype=np.float32) + >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],\ + dtype=np.float32) + >>> iou_threshold = 0.6 + >>> dets, inds = nms(boxes, scores, iou_threshold) + >>> assert len(inds) == len(dets) == 3 + """ + assert isinstance(boxes, (torch.Tensor, np.ndarray)) + assert isinstance(scores, (torch.Tensor, np.ndarray)) + is_numpy = False + if isinstance(boxes, np.ndarray): + is_numpy = True + boxes = torch.from_numpy(boxes) + if isinstance(scores, np.ndarray): + scores = torch.from_numpy(scores) + assert boxes.size(1) == 4 + assert boxes.size(0) == scores.size(0) + assert offset in (0, 1) + + if torch.__version__ == 'parrots': + indata_list = [boxes, scores] + indata_dict = { + 'iou_threshold': float(iou_threshold), + 'offset': int(offset) + } + inds = ext_module.nms(*indata_list, **indata_dict) + else: + inds = NMSop.apply(boxes, scores, iou_threshold, offset, + score_threshold, max_num) + dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1) + if is_numpy: + dets = dets.cpu().numpy() + inds = inds.cpu().numpy() + return dets, inds + + +@deprecated_api_warning({'iou_thr': 'iou_threshold'}) +def soft_nms(boxes, + scores, + iou_threshold=0.3, + sigma=0.5, + min_score=1e-3, + method='linear', + offset=0): + """Dispatch to only CPU Soft NMS implementations. + + The input can be either a torch tensor or numpy array. + The returned type will always be the same as inputs. + + Arguments: + boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4). + scores (torch.Tensor or np.ndarray): scores in shape (N, ). + iou_threshold (float): IoU threshold for NMS. + sigma (float): hyperparameter for gaussian method + min_score (float): score filter threshold + method (str): either 'linear' or 'gaussian' + offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset). + + Returns: + tuple: kept dets(boxes and scores) and indice, which is always the \ + same data type as the input. + + Example: + >>> boxes = np.array([[4., 3., 5., 3.], + >>> [4., 3., 5., 4.], + >>> [3., 1., 3., 1.], + >>> [3., 1., 3., 1.], + >>> [3., 1., 3., 1.], + >>> [3., 1., 3., 1.]], dtype=np.float32) + >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.4, 0.0], dtype=np.float32) + >>> iou_threshold = 0.6 + >>> dets, inds = soft_nms(boxes, scores, iou_threshold, sigma=0.5) + >>> assert len(inds) == len(dets) == 5 + """ + + assert isinstance(boxes, (torch.Tensor, np.ndarray)) + assert isinstance(scores, (torch.Tensor, np.ndarray)) + is_numpy = False + if isinstance(boxes, np.ndarray): + is_numpy = True + boxes = torch.from_numpy(boxes) + if isinstance(scores, np.ndarray): + scores = torch.from_numpy(scores) + assert boxes.size(1) == 4 + assert boxes.size(0) == scores.size(0) + assert offset in (0, 1) + method_dict = {'naive': 0, 'linear': 1, 'gaussian': 2} + assert method in method_dict.keys() + + if torch.__version__ == 'parrots': + dets = boxes.new_empty((boxes.size(0), 5), device='cpu') + indata_list = [boxes.cpu(), scores.cpu(), dets.cpu()] + indata_dict = { + 'iou_threshold': float(iou_threshold), + 'sigma': float(sigma), + 'min_score': min_score, + 'method': method_dict[method], + 'offset': int(offset) + } + inds = ext_module.softnms(*indata_list, **indata_dict) + else: + dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(), + float(iou_threshold), float(sigma), + float(min_score), method_dict[method], + int(offset)) + + dets = dets[:inds.size(0)] + + if is_numpy: + dets = dets.cpu().numpy() + inds = inds.cpu().numpy() + return dets, inds + else: + return dets.to(device=boxes.device), inds.to(device=boxes.device) + + +def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): + """Performs non-maximum suppression in a batched fashion. + + Modified from https://github.com/pytorch/vision/blob + /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39. + In order to perform NMS independently per class, we add an offset to all + the boxes. The offset is dependent only on the class idx, and is large + enough so that boxes from different classes do not overlap. + + Arguments: + boxes (torch.Tensor): boxes in shape (N, 4). + scores (torch.Tensor): scores in shape (N, ). + idxs (torch.Tensor): each index value correspond to a bbox cluster, + and NMS will not be applied between elements of different idxs, + shape (N, ). + nms_cfg (dict): specify nms type and other parameters like iou_thr. + Possible keys includes the following. + + - iou_thr (float): IoU threshold used for NMS. + - split_thr (float): threshold number of boxes. In some cases the + number of boxes is large (e.g., 200k). To avoid OOM during + training, the users could set `split_thr` to a small value. + If the number of boxes is greater than the threshold, it will + perform NMS on each group of boxes separately and sequentially. + Defaults to 10000. + class_agnostic (bool): if true, nms is class agnostic, + i.e. IoU thresholding happens over all boxes, + regardless of the predicted class. + + Returns: + tuple: kept dets and indice. + """ + nms_cfg_ = nms_cfg.copy() + class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic) + if class_agnostic: + boxes_for_nms = boxes + else: + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) + boxes_for_nms = boxes + offsets[:, None] + + nms_type = nms_cfg_.pop('type', 'nms') + nms_op = eval(nms_type) + + split_thr = nms_cfg_.pop('split_thr', 10000) + # Won't split to multiple nms nodes when exporting to onnx + if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export(): + dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_) + boxes = boxes[keep] + # -1 indexing works abnormal in TensorRT + # This assumes `dets` has 5 dimensions where + # the last dimension is score. + # TODO: more elegant way to handle the dimension issue. + # Some type of nms would reweight the score, such as SoftNMS + scores = dets[:, 4] + else: + max_num = nms_cfg_.pop('max_num', -1) + total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) + # Some type of nms would reweight the score, such as SoftNMS + scores_after_nms = scores.new_zeros(scores.size()) + for id in torch.unique(idxs): + mask = (idxs == id).nonzero(as_tuple=False).view(-1) + dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_) + total_mask[mask[keep]] = True + scores_after_nms[mask[keep]] = dets[:, -1] + keep = total_mask.nonzero(as_tuple=False).view(-1) + + scores, inds = scores_after_nms[keep].sort(descending=True) + keep = keep[inds] + boxes = boxes[keep] + + if max_num > 0: + keep = keep[:max_num] + boxes = boxes[:max_num] + scores = scores[:max_num] + + return torch.cat([boxes, scores[:, None]], -1), keep + + +def nms_match(dets, iou_threshold): + """Matched dets into different groups by NMS. + + NMS match is Similar to NMS but when a bbox is suppressed, nms match will + record the indice of suppressed bbox and form a group with the indice of + kept bbox. In each group, indice is sorted as score order. + + Arguments: + dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5). + iou_thr (float): IoU thresh for NMS. + + Returns: + List[torch.Tensor | np.ndarray]: The outer list corresponds different + matched group, the inner Tensor corresponds the indices for a group + in score order. + """ + if dets.shape[0] == 0: + matched = [] + else: + assert dets.shape[-1] == 5, 'inputs dets.shape should be (N, 5), ' \ + f'but get {dets.shape}' + if isinstance(dets, torch.Tensor): + dets_t = dets.detach().cpu() + else: + dets_t = torch.from_numpy(dets) + indata_list = [dets_t] + indata_dict = {'iou_threshold': float(iou_threshold)} + matched = ext_module.nms_match(*indata_list, **indata_dict) + if torch.__version__ == 'parrots': + matched = matched.tolist() + + if isinstance(dets, torch.Tensor): + return [dets.new_tensor(m, dtype=torch.long) for m in matched] + else: + return [np.array(m, dtype=np.int) for m in matched] + + +def nms_rotated(dets, scores, iou_threshold, labels=None): + """Performs non-maximum suppression (NMS) on the rotated boxes according to + their intersection-over-union (IoU). + + Rotated NMS iteratively removes lower scoring rotated boxes which have an + IoU greater than iou_threshold with another (higher scoring) rotated box. + + Args: + boxes (Tensor): Rotated boxes in shape (N, 5). They are expected to \ + be in (x_ctr, y_ctr, width, height, angle_radian) format. + scores (Tensor): scores in shape (N, ). + iou_threshold (float): IoU thresh for NMS. + labels (Tensor): boxes' label in shape (N,). + + Returns: + tuple: kept dets(boxes and scores) and indice, which is always the \ + same data type as the input. + """ + if dets.shape[0] == 0: + return dets, None + multi_label = labels is not None + if multi_label: + dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1) + else: + dets_wl = dets + _, order = scores.sort(0, descending=True) + dets_sorted = dets_wl.index_select(0, order) + + if torch.__version__ == 'parrots': + keep_inds = ext_module.nms_rotated( + dets_wl, + scores, + order, + dets_sorted, + iou_threshold=iou_threshold, + multi_label=multi_label) + else: + keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted, + iou_threshold, multi_label) + dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)), + dim=1) + return dets, keep_inds diff --git a/annotator/uniformer/mmcv/ops/pixel_group.py b/annotator/uniformer/mmcv/ops/pixel_group.py new file mode 100644 index 0000000000000000000000000000000000000000..2143c75f835a467c802fc3c37ecd3ac0f85bcda4 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/pixel_group.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['pixel_group']) + + +def pixel_group(score, mask, embedding, kernel_label, kernel_contour, + kernel_region_num, distance_threshold): + """Group pixels into text instances, which is widely used text detection + methods. + + Arguments: + score (np.array or Tensor): The foreground score with size hxw. + mask (np.array or Tensor): The foreground mask with size hxw. + embedding (np.array or Tensor): The embedding with size hxwxc to + distinguish instances. + kernel_label (np.array or Tensor): The instance kernel index with + size hxw. + kernel_contour (np.array or Tensor): The kernel contour with size hxw. + kernel_region_num (int): The instance kernel region number. + distance_threshold (float): The embedding distance threshold between + kernel and pixel in one instance. + + Returns: + pixel_assignment (List[List[float]]): The instance coordinate list. + Each element consists of averaged confidence, pixel number, and + coordinates (x_i, y_i for all pixels) in order. + """ + assert isinstance(score, (torch.Tensor, np.ndarray)) + assert isinstance(mask, (torch.Tensor, np.ndarray)) + assert isinstance(embedding, (torch.Tensor, np.ndarray)) + assert isinstance(kernel_label, (torch.Tensor, np.ndarray)) + assert isinstance(kernel_contour, (torch.Tensor, np.ndarray)) + assert isinstance(kernel_region_num, int) + assert isinstance(distance_threshold, float) + + if isinstance(score, np.ndarray): + score = torch.from_numpy(score) + if isinstance(mask, np.ndarray): + mask = torch.from_numpy(mask) + if isinstance(embedding, np.ndarray): + embedding = torch.from_numpy(embedding) + if isinstance(kernel_label, np.ndarray): + kernel_label = torch.from_numpy(kernel_label) + if isinstance(kernel_contour, np.ndarray): + kernel_contour = torch.from_numpy(kernel_contour) + + if torch.__version__ == 'parrots': + label = ext_module.pixel_group( + score, + mask, + embedding, + kernel_label, + kernel_contour, + kernel_region_num=kernel_region_num, + distance_threshold=distance_threshold) + label = label.tolist() + label = label[0] + list_index = kernel_region_num + pixel_assignment = [] + for x in range(kernel_region_num): + pixel_assignment.append( + np.array( + label[list_index:list_index + int(label[x])], + dtype=np.float)) + list_index = list_index + int(label[x]) + else: + pixel_assignment = ext_module.pixel_group(score, mask, embedding, + kernel_label, kernel_contour, + kernel_region_num, + distance_threshold) + return pixel_assignment diff --git a/annotator/uniformer/mmcv/ops/point_sample.py b/annotator/uniformer/mmcv/ops/point_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..267f4b3c56630acd85f9bdc630b7be09abab0aba --- /dev/null +++ b/annotator/uniformer/mmcv/ops/point_sample.py @@ -0,0 +1,336 @@ +# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa + +from os import path as osp + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair +from torch.onnx.operators import shape_as_tensor + + +def bilinear_grid_sample(im, grid, align_corners=False): + """Given an input and a flow-field grid, computes the output using input + values and pixel locations from grid. Supported only bilinear interpolation + method to sample the input pixels. + + Args: + im (torch.Tensor): Input feature map, shape (N, C, H, W) + grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2) + align_corners {bool}: If set to True, the extrema (-1 and 1) are + considered as referring to the center points of the input’s + corner pixels. If set to False, they are instead considered as + referring to the corner points of the input’s corner pixels, + making the sampling more resolution agnostic. + Returns: + torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg) + """ + n, c, h, w = im.shape + gn, gh, gw, _ = grid.shape + assert n == gn + + x = grid[:, :, :, 0] + y = grid[:, :, :, 1] + + if align_corners: + x = ((x + 1) / 2) * (w - 1) + y = ((y + 1) / 2) * (h - 1) + else: + x = ((x + 1) * w - 1) / 2 + y = ((y + 1) * h - 1) / 2 + + x = x.view(n, -1) + y = y.view(n, -1) + + x0 = torch.floor(x).long() + y0 = torch.floor(y).long() + x1 = x0 + 1 + y1 = y0 + 1 + + wa = ((x1 - x) * (y1 - y)).unsqueeze(1) + wb = ((x1 - x) * (y - y0)).unsqueeze(1) + wc = ((x - x0) * (y1 - y)).unsqueeze(1) + wd = ((x - x0) * (y - y0)).unsqueeze(1) + + # Apply default for grid_sample function zero padding + im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0) + padded_h = h + 2 + padded_w = w + 2 + # save points positions after padding + x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1 + + # Clip coordinates to padded image size + x0 = torch.where(x0 < 0, torch.tensor(0), x0) + x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0) + x1 = torch.where(x1 < 0, torch.tensor(0), x1) + x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1) + y0 = torch.where(y0 < 0, torch.tensor(0), y0) + y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0) + y1 = torch.where(y1 < 0, torch.tensor(0), y1) + y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1) + + im_padded = im_padded.view(n, c, -1) + + x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1) + x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1) + x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1) + x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1) + + Ia = torch.gather(im_padded, 2, x0_y0) + Ib = torch.gather(im_padded, 2, x0_y1) + Ic = torch.gather(im_padded, 2, x1_y0) + Id = torch.gather(im_padded, 2, x1_y1) + + return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw) + + +def is_in_onnx_export_without_custom_ops(): + from annotator.uniformer.mmcv.ops import get_onnxruntime_op_path + ort_custom_op_path = get_onnxruntime_op_path() + return torch.onnx.is_in_onnx_export( + ) and not osp.exists(ort_custom_op_path) + + +def normalize(grid): + """Normalize input grid from [-1, 1] to [0, 1] + Args: + grid (Tensor): The grid to be normalize, range [-1, 1]. + Returns: + Tensor: Normalized grid, range [0, 1]. + """ + + return (grid + 1.0) / 2.0 + + +def denormalize(grid): + """Denormalize input grid from range [0, 1] to [-1, 1] + Args: + grid (Tensor): The grid to be denormalize, range [0, 1]. + Returns: + Tensor: Denormalized grid, range [-1, 1]. + """ + + return grid * 2.0 - 1.0 + + +def generate_grid(num_grid, size, device): + """Generate regular square grid of points in [0, 1] x [0, 1] coordinate + space. + + Args: + num_grid (int): The number of grids to sample, one for each region. + size (tuple(int, int)): The side size of the regular grid. + device (torch.device): Desired device of returned tensor. + + Returns: + (torch.Tensor): A tensor of shape (num_grid, size[0]*size[1], 2) that + contains coordinates for the regular grids. + """ + + affine_trans = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device) + grid = F.affine_grid( + affine_trans, torch.Size((1, 1, *size)), align_corners=False) + grid = normalize(grid) + return grid.view(1, -1, 2).expand(num_grid, -1, -1) + + +def rel_roi_point_to_abs_img_point(rois, rel_roi_points): + """Convert roi based relative point coordinates to image based absolute + point coordinates. + + Args: + rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5) + rel_roi_points (Tensor): Point coordinates inside RoI, relative to + RoI, location, range (0, 1), shape (N, P, 2) + Returns: + Tensor: Image based absolute point coordinates, shape (N, P, 2) + """ + + with torch.no_grad(): + assert rel_roi_points.size(0) == rois.size(0) + assert rois.dim() == 2 + assert rel_roi_points.dim() == 3 + assert rel_roi_points.size(2) == 2 + # remove batch idx + if rois.size(1) == 5: + rois = rois[:, 1:] + abs_img_points = rel_roi_points.clone() + # To avoid an error during exporting to onnx use independent + # variables instead inplace computation + xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0]) + ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1]) + xs += rois[:, None, 0] + ys += rois[:, None, 1] + abs_img_points = torch.stack([xs, ys], dim=2) + return abs_img_points + + +def get_shape_from_feature_map(x): + """Get spatial resolution of input feature map considering exporting to + onnx mode. + + Args: + x (torch.Tensor): Input tensor, shape (N, C, H, W) + Returns: + torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2) + """ + if torch.onnx.is_in_onnx_export(): + img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to( + x.device).float() + else: + img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to( + x.device).float() + return img_shape + + +def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.): + """Convert image based absolute point coordinates to image based relative + coordinates for sampling. + + Args: + abs_img_points (Tensor): Image based absolute point coordinates, + shape (N, P, 2) + img (tuple/Tensor): (height, width) of image or feature map. + spatial_scale (float): Scale points by this factor. Default: 1. + + Returns: + Tensor: Image based relative point coordinates for sampling, + shape (N, P, 2) + """ + + assert (isinstance(img, tuple) and len(img) == 2) or \ + (isinstance(img, torch.Tensor) and len(img.shape) == 4) + + if isinstance(img, tuple): + h, w = img + scale = torch.tensor([w, h], + dtype=torch.float, + device=abs_img_points.device) + scale = scale.view(1, 1, 2) + else: + scale = get_shape_from_feature_map(img) + + return abs_img_points / scale * spatial_scale + + +def rel_roi_point_to_rel_img_point(rois, + rel_roi_points, + img, + spatial_scale=1.): + """Convert roi based relative point coordinates to image based absolute + point coordinates. + + Args: + rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5) + rel_roi_points (Tensor): Point coordinates inside RoI, relative to + RoI, location, range (0, 1), shape (N, P, 2) + img (tuple/Tensor): (height, width) of image or feature map. + spatial_scale (float): Scale points by this factor. Default: 1. + + Returns: + Tensor: Image based relative point coordinates for sampling, + shape (N, P, 2) + """ + + abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points) + rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img, + spatial_scale) + + return rel_img_point + + +def point_sample(input, points, align_corners=False, **kwargs): + """A wrapper around :func:`grid_sample` to support 3D point_coords tensors + Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to + lie inside ``[0, 1] x [0, 1]`` square. + + Args: + input (Tensor): Feature map, shape (N, C, H, W). + points (Tensor): Image based absolute point coordinates (normalized), + range [0, 1] x [0, 1], shape (N, P, 2) or (N, Hgrid, Wgrid, 2). + align_corners (bool): Whether align_corners. Default: False + + Returns: + Tensor: Features of `point` on `input`, shape (N, C, P) or + (N, C, Hgrid, Wgrid). + """ + + add_dim = False + if points.dim() == 3: + add_dim = True + points = points.unsqueeze(2) + if is_in_onnx_export_without_custom_ops(): + # If custom ops for onnx runtime not compiled use python + # implementation of grid_sample function to make onnx graph + # with supported nodes + output = bilinear_grid_sample( + input, denormalize(points), align_corners=align_corners) + else: + output = F.grid_sample( + input, denormalize(points), align_corners=align_corners, **kwargs) + if add_dim: + output = output.squeeze(3) + return output + + +class SimpleRoIAlign(nn.Module): + + def __init__(self, output_size, spatial_scale, aligned=True): + """Simple RoI align in PointRend, faster than standard RoIAlign. + + Args: + output_size (tuple[int]): h, w + spatial_scale (float): scale the input boxes by this number + aligned (bool): if False, use the legacy implementation in + MMDetection, align_corners=True will be used in F.grid_sample. + If True, align the results more perfectly. + """ + + super(SimpleRoIAlign, self).__init__() + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + # to be consistent with other RoI ops + self.use_torchvision = False + self.aligned = aligned + + def forward(self, features, rois): + num_imgs = features.size(0) + num_rois = rois.size(0) + rel_roi_points = generate_grid( + num_rois, self.output_size, device=rois.device) + + if torch.onnx.is_in_onnx_export(): + rel_img_points = rel_roi_point_to_rel_img_point( + rois, rel_roi_points, features, self.spatial_scale) + rel_img_points = rel_img_points.reshape(num_imgs, -1, + *rel_img_points.shape[1:]) + point_feats = point_sample( + features, rel_img_points, align_corners=not self.aligned) + point_feats = point_feats.transpose(1, 2) + else: + point_feats = [] + for batch_ind in range(num_imgs): + # unravel batch dim + feat = features[batch_ind].unsqueeze(0) + inds = (rois[:, 0].long() == batch_ind) + if inds.any(): + rel_img_points = rel_roi_point_to_rel_img_point( + rois[inds], rel_roi_points[inds], feat, + self.spatial_scale).unsqueeze(0) + point_feat = point_sample( + feat, rel_img_points, align_corners=not self.aligned) + point_feat = point_feat.squeeze(0).transpose(0, 1) + point_feats.append(point_feat) + + point_feats = torch.cat(point_feats, dim=0) + + channels = features.size(1) + roi_feats = point_feats.reshape(num_rois, channels, *self.output_size) + + return roi_feats + + def __repr__(self): + format_str = self.__class__.__name__ + format_str += '(output_size={}, spatial_scale={}'.format( + self.output_size, self.spatial_scale) + return format_str diff --git a/annotator/uniformer/mmcv/ops/points_in_boxes.py b/annotator/uniformer/mmcv/ops/points_in_boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..4003173a53052161dbcd687a2fa1d755642fdab8 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/points_in_boxes.py @@ -0,0 +1,133 @@ +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward', + 'points_in_boxes_all_forward' +]) + + +def points_in_boxes_part(points, boxes): + """Find the box in which each point is (CUDA). + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in + LiDAR/DEPTH coordinate, (x, y, z) is the bottom center + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M), default background = -1 + """ + assert points.shape[0] == boxes.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {points.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + + box_idxs_of_pts = points.new_zeros((batch_size, num_points), + dtype=torch.int).fill_(-1) + + # If manually put the tensor 'points' or 'boxes' on a device + # which is not the current device, some temporary variables + # will be created on the current device in the cuda op, + # and the output will be incorrect. + # Therefore, we force the current device to be the same + # as the device of the tensors if it was not. + # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305 + # for the incorrect output before the fix. + points_device = points.get_device() + assert points_device == boxes.get_device(), \ + 'Points and boxes should be put on the same device' + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + + ext_module.points_in_boxes_part_forward(boxes.contiguous(), + points.contiguous(), + box_idxs_of_pts) + + return box_idxs_of_pts + + +def points_in_boxes_cpu(points, boxes): + """Find all boxes in which each point is (CPU). The CPU version of + :meth:`points_in_boxes_all`. + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in + LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz], + (x, y, z) is the bottom center. + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0. + """ + assert points.shape[0] == boxes.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {points.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + num_boxes = boxes.shape[1] + + point_indices = points.new_zeros((batch_size, num_boxes, num_points), + dtype=torch.int) + for b in range(batch_size): + ext_module.points_in_boxes_cpu_forward(boxes[b].float().contiguous(), + points[b].float().contiguous(), + point_indices[b]) + point_indices = point_indices.transpose(1, 2) + + return point_indices + + +def points_in_boxes_all(points, boxes): + """Find all boxes in which each point is (CUDA). + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz], + (x, y, z) is the bottom center. + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0. + """ + assert boxes.shape[0] == points.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {boxes.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + num_boxes = boxes.shape[1] + + box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes), + dtype=torch.int).fill_(0) + + # Same reason as line 25-32 + points_device = points.get_device() + assert points_device == boxes.get_device(), \ + 'Points and boxes should be put on the same device' + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + + ext_module.points_in_boxes_all_forward(boxes.contiguous(), + points.contiguous(), + box_idxs_of_pts) + + return box_idxs_of_pts diff --git a/annotator/uniformer/mmcv/ops/points_sampler.py b/annotator/uniformer/mmcv/ops/points_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a802a74fd6c3610d9ae178e6201f47423eca7ad1 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/points_sampler.py @@ -0,0 +1,177 @@ +from typing import List + +import torch +from torch import nn as nn + +from annotator.uniformer.mmcv.runner import force_fp32 +from .furthest_point_sample import (furthest_point_sample, + furthest_point_sample_with_dist) + + +def calc_square_dist(point_feat_a, point_feat_b, norm=True): + """Calculating square distance between a and b. + + Args: + point_feat_a (Tensor): (B, N, C) Feature vector of each point. + point_feat_b (Tensor): (B, M, C) Feature vector of each point. + norm (Bool, optional): Whether to normalize the distance. + Default: True. + + Returns: + Tensor: (B, N, M) Distance between each pair points. + """ + num_channel = point_feat_a.shape[-1] + # [bs, n, 1] + a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1) + # [bs, 1, m] + b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1) + + corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2)) + + dist = a_square + b_square - 2 * corr_matrix + if norm: + dist = torch.sqrt(dist) / num_channel + return dist + + +def get_sampler_cls(sampler_type): + """Get the type and mode of points sampler. + + Args: + sampler_type (str): The type of points sampler. + The valid value are "D-FPS", "F-FPS", or "FS". + + Returns: + class: Points sampler type. + """ + sampler_mappings = { + 'D-FPS': DFPSSampler, + 'F-FPS': FFPSSampler, + 'FS': FSSampler, + } + try: + return sampler_mappings[sampler_type] + except KeyError: + raise KeyError( + f'Supported `sampler_type` are {sampler_mappings.keys()}, but got \ + {sampler_type}') + + +class PointsSampler(nn.Module): + """Points sampling. + + Args: + num_point (list[int]): Number of sample points. + fps_mod_list (list[str], optional): Type of FPS method, valid mod + ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS']. + F-FPS: using feature distances for FPS. + D-FPS: using Euclidean distances of points for FPS. + FS: using F-FPS and D-FPS simultaneously. + fps_sample_range_list (list[int], optional): + Range of points to apply FPS. Default: [-1]. + """ + + def __init__(self, + num_point: List[int], + fps_mod_list: List[str] = ['D-FPS'], + fps_sample_range_list: List[int] = [-1]): + super().__init__() + # FPS would be applied to different fps_mod in the list, + # so the length of the num_point should be equal to + # fps_mod_list and fps_sample_range_list. + assert len(num_point) == len(fps_mod_list) == len( + fps_sample_range_list) + self.num_point = num_point + self.fps_sample_range_list = fps_sample_range_list + self.samplers = nn.ModuleList() + for fps_mod in fps_mod_list: + self.samplers.append(get_sampler_cls(fps_mod)()) + self.fp16_enabled = False + + @force_fp32() + def forward(self, points_xyz, features): + """ + Args: + points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. + features (Tensor): (B, C, N) Descriptors of the features. + + Returns: + Tensor: (B, npoint, sample_num) Indices of sampled points. + """ + indices = [] + last_fps_end_index = 0 + + for fps_sample_range, sampler, npoint in zip( + self.fps_sample_range_list, self.samplers, self.num_point): + assert fps_sample_range < points_xyz.shape[1] + + if fps_sample_range == -1: + sample_points_xyz = points_xyz[:, last_fps_end_index:] + if features is not None: + sample_features = features[:, :, last_fps_end_index:] + else: + sample_features = None + else: + sample_points_xyz = \ + points_xyz[:, last_fps_end_index:fps_sample_range] + if features is not None: + sample_features = features[:, :, last_fps_end_index: + fps_sample_range] + else: + sample_features = None + + fps_idx = sampler(sample_points_xyz.contiguous(), sample_features, + npoint) + + indices.append(fps_idx + last_fps_end_index) + last_fps_end_index += fps_sample_range + indices = torch.cat(indices, dim=1) + + return indices + + +class DFPSSampler(nn.Module): + """Using Euclidean distances of points for FPS.""" + + def __init__(self): + super().__init__() + + def forward(self, points, features, npoint): + """Sampling points with D-FPS.""" + fps_idx = furthest_point_sample(points.contiguous(), npoint) + return fps_idx + + +class FFPSSampler(nn.Module): + """Using feature distances for FPS.""" + + def __init__(self): + super().__init__() + + def forward(self, points, features, npoint): + """Sampling points with F-FPS.""" + assert features is not None, \ + 'feature input to FFPS_Sampler should not be None' + features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2) + features_dist = calc_square_dist( + features_for_fps, features_for_fps, norm=False) + fps_idx = furthest_point_sample_with_dist(features_dist, npoint) + return fps_idx + + +class FSSampler(nn.Module): + """Using F-FPS and D-FPS simultaneously.""" + + def __init__(self): + super().__init__() + + def forward(self, points, features, npoint): + """Sampling points with FS_Sampling.""" + assert features is not None, \ + 'feature input to FS_Sampler should not be None' + ffps_sampler = FFPSSampler() + dfps_sampler = DFPSSampler() + fps_idx_ffps = ffps_sampler(points, features, npoint) + fps_idx_dfps = dfps_sampler(points, features, npoint) + fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1) + return fps_idx diff --git a/annotator/uniformer/mmcv/ops/psa_mask.py b/annotator/uniformer/mmcv/ops/psa_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf14e62b50e8d4dd6856c94333c703bcc4c9ab6 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/psa_mask.py @@ -0,0 +1,92 @@ +# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa +from torch import nn +from torch.autograd import Function +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['psamask_forward', 'psamask_backward']) + + +class PSAMaskFunction(Function): + + @staticmethod + def symbolic(g, input, psa_type, mask_size): + return g.op( + 'mmcv::MMCVPSAMask', + input, + psa_type_i=psa_type, + mask_size_i=mask_size) + + @staticmethod + def forward(ctx, input, psa_type, mask_size): + ctx.psa_type = psa_type + ctx.mask_size = _pair(mask_size) + ctx.save_for_backward(input) + + h_mask, w_mask = ctx.mask_size + batch_size, channels, h_feature, w_feature = input.size() + assert channels == h_mask * w_mask + output = input.new_zeros( + (batch_size, h_feature * w_feature, h_feature, w_feature)) + + ext_module.psamask_forward( + input, + output, + psa_type=psa_type, + num_=batch_size, + h_feature=h_feature, + w_feature=w_feature, + h_mask=h_mask, + w_mask=w_mask, + half_h_mask=(h_mask - 1) // 2, + half_w_mask=(w_mask - 1) // 2) + return output + + @staticmethod + def backward(ctx, grad_output): + input = ctx.saved_tensors[0] + psa_type = ctx.psa_type + h_mask, w_mask = ctx.mask_size + batch_size, channels, h_feature, w_feature = input.size() + grad_input = grad_output.new_zeros( + (batch_size, channels, h_feature, w_feature)) + ext_module.psamask_backward( + grad_output, + grad_input, + psa_type=psa_type, + num_=batch_size, + h_feature=h_feature, + w_feature=w_feature, + h_mask=h_mask, + w_mask=w_mask, + half_h_mask=(h_mask - 1) // 2, + half_w_mask=(w_mask - 1) // 2) + return grad_input, None, None, None + + +psa_mask = PSAMaskFunction.apply + + +class PSAMask(nn.Module): + + def __init__(self, psa_type, mask_size=None): + super(PSAMask, self).__init__() + assert psa_type in ['collect', 'distribute'] + if psa_type == 'collect': + psa_type_enum = 0 + else: + psa_type_enum = 1 + self.psa_type_enum = psa_type_enum + self.mask_size = mask_size + self.psa_type = psa_type + + def forward(self, input): + return psa_mask(input, self.psa_type_enum, self.mask_size) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(psa_type={self.psa_type}, ' + s += f'mask_size={self.mask_size})' + return s diff --git a/annotator/uniformer/mmcv/ops/roi_align.py b/annotator/uniformer/mmcv/ops/roi_align.py new file mode 100644 index 0000000000000000000000000000000000000000..0755aefc66e67233ceae0f4b77948301c443e9fb --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roi_align.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import deprecated_api_warning, ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['roi_align_forward', 'roi_align_backward']) + + +class RoIAlignFunction(Function): + + @staticmethod + def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio, + pool_mode, aligned): + from ..onnx import is_custom_op_loaded + has_custom_op = is_custom_op_loaded() + if has_custom_op: + return g.op( + 'mmcv::MMCVRoiAlign', + input, + rois, + output_height_i=output_size[0], + output_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_i=sampling_ratio, + mode_s=pool_mode, + aligned_i=aligned) + else: + from torch.onnx.symbolic_opset9 import sub, squeeze + from torch.onnx.symbolic_helper import _slice_helper + from torch.onnx import TensorProtoDataType + # batch_indices = rois[:, 0].long() + batch_indices = _slice_helper( + g, rois, axes=[1], starts=[0], ends=[1]) + batch_indices = squeeze(g, batch_indices, 1) + batch_indices = g.op( + 'Cast', batch_indices, to_i=TensorProtoDataType.INT64) + # rois = rois[:, 1:] + rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5]) + if aligned: + # rois -= 0.5/spatial_scale + aligned_offset = g.op( + 'Constant', + value_t=torch.tensor([0.5 / spatial_scale], + dtype=torch.float32)) + rois = sub(g, rois, aligned_offset) + # roi align + return g.op( + 'RoiAlign', + input, + rois, + batch_indices, + output_height_i=output_size[0], + output_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_i=max(0, sampling_ratio), + mode_s=pool_mode) + + @staticmethod + def forward(ctx, + input, + rois, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + pool_mode='avg', + aligned=True): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.sampling_ratio = sampling_ratio + assert pool_mode in ('max', 'avg') + ctx.pool_mode = 0 if pool_mode == 'max' else 1 + ctx.aligned = aligned + ctx.input_shape = input.size() + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' + + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + if ctx.pool_mode == 0: + argmax_y = input.new_zeros(output_shape) + argmax_x = input.new_zeros(output_shape) + else: + argmax_y = input.new_zeros(0) + argmax_x = input.new_zeros(0) + + ext_module.roi_align_forward( + input, + rois, + output, + argmax_y, + argmax_x, + aligned_height=ctx.output_size[0], + aligned_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + pool_mode=ctx.pool_mode, + aligned=ctx.aligned) + + ctx.save_for_backward(rois, argmax_y, argmax_x) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, argmax_y, argmax_x = ctx.saved_tensors + grad_input = grad_output.new_zeros(ctx.input_shape) + # complex head architecture may cause grad_output uncontiguous. + grad_output = grad_output.contiguous() + ext_module.roi_align_backward( + grad_output, + rois, + argmax_y, + argmax_x, + grad_input, + aligned_height=ctx.output_size[0], + aligned_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + pool_mode=ctx.pool_mode, + aligned=ctx.aligned) + return grad_input, None, None, None, None, None, None + + +roi_align = RoIAlignFunction.apply + + +class RoIAlign(nn.Module): + """RoI align pooling layer. + + Args: + output_size (tuple): h, w + spatial_scale (float): scale the input boxes by this number + sampling_ratio (int): number of inputs samples to take for each + output sample. 0 to take samples densely for current models. + pool_mode (str, 'avg' or 'max'): pooling mode in each bin. + aligned (bool): if False, use the legacy implementation in + MMDetection. If True, align the results more perfectly. + use_torchvision (bool): whether to use roi_align from torchvision. + + Note: + The implementation of RoIAlign when aligned=True is modified from + https://github.com/facebookresearch/detectron2/ + + The meaning of aligned=True: + + Given a continuous coordinate c, its two neighboring pixel + indices (in our pixel model) are computed by floor(c - 0.5) and + ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete + indices [0] and [1] (which are sampled from the underlying signal + at continuous coordinates 0.5 and 1.5). But the original roi_align + (aligned=False) does not subtract the 0.5 when computing + neighboring pixel indices and therefore it uses pixels with a + slightly incorrect alignment (relative to our pixel model) when + performing bilinear interpolation. + + With `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5 + prior to calling roi_align. This produces the correct neighbors; + + The difference does not make a difference to the model's + performance if ROIAlign is used together with conv layers. + """ + + @deprecated_api_warning( + { + 'out_size': 'output_size', + 'sample_num': 'sampling_ratio' + }, + cls_name='RoIAlign') + def __init__(self, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + pool_mode='avg', + aligned=True, + use_torchvision=False): + super(RoIAlign, self).__init__() + + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + self.sampling_ratio = int(sampling_ratio) + self.pool_mode = pool_mode + self.aligned = aligned + self.use_torchvision = use_torchvision + + def forward(self, input, rois): + """ + Args: + input: NCHW images + rois: Bx5 boxes. First column is the index into N.\ + The other 4 columns are xyxy. + """ + if self.use_torchvision: + from torchvision.ops import roi_align as tv_roi_align + if 'aligned' in tv_roi_align.__code__.co_varnames: + return tv_roi_align(input, rois, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.aligned) + else: + if self.aligned: + rois -= rois.new_tensor([0.] + + [0.5 / self.spatial_scale] * 4) + return tv_roi_align(input, rois, self.output_size, + self.spatial_scale, self.sampling_ratio) + else: + return roi_align(input, rois, self.output_size, self.spatial_scale, + self.sampling_ratio, self.pool_mode, self.aligned) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(output_size={self.output_size}, ' + s += f'spatial_scale={self.spatial_scale}, ' + s += f'sampling_ratio={self.sampling_ratio}, ' + s += f'pool_mode={self.pool_mode}, ' + s += f'aligned={self.aligned}, ' + s += f'use_torchvision={self.use_torchvision})' + return s diff --git a/annotator/uniformer/mmcv/ops/roi_align_rotated.py b/annotator/uniformer/mmcv/ops/roi_align_rotated.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce4961a3555d4da8bc3e32f1f7d5ad50036587d --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roi_align_rotated.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward']) + + +class RoIAlignRotatedFunction(Function): + + @staticmethod + def symbolic(g, features, rois, out_size, spatial_scale, sample_num, + aligned, clockwise): + if isinstance(out_size, int): + out_h = out_size + out_w = out_size + elif isinstance(out_size, tuple): + assert len(out_size) == 2 + assert isinstance(out_size[0], int) + assert isinstance(out_size[1], int) + out_h, out_w = out_size + else: + raise TypeError( + '"out_size" must be an integer or tuple of integers') + return g.op( + 'mmcv::MMCVRoIAlignRotated', + features, + rois, + output_height_i=out_h, + output_width_i=out_h, + spatial_scale_f=spatial_scale, + sampling_ratio_i=sample_num, + aligned_i=aligned, + clockwise_i=clockwise) + + @staticmethod + def forward(ctx, + features, + rois, + out_size, + spatial_scale, + sample_num=0, + aligned=True, + clockwise=False): + if isinstance(out_size, int): + out_h = out_size + out_w = out_size + elif isinstance(out_size, tuple): + assert len(out_size) == 2 + assert isinstance(out_size[0], int) + assert isinstance(out_size[1], int) + out_h, out_w = out_size + else: + raise TypeError( + '"out_size" must be an integer or tuple of integers') + ctx.spatial_scale = spatial_scale + ctx.sample_num = sample_num + ctx.aligned = aligned + ctx.clockwise = clockwise + ctx.save_for_backward(rois) + ctx.feature_size = features.size() + + batch_size, num_channels, data_height, data_width = features.size() + num_rois = rois.size(0) + + output = features.new_zeros(num_rois, num_channels, out_h, out_w) + ext_module.roi_align_rotated_forward( + features, + rois, + output, + pooled_height=out_h, + pooled_width=out_w, + spatial_scale=spatial_scale, + sample_num=sample_num, + aligned=aligned, + clockwise=clockwise) + return output + + @staticmethod + def backward(ctx, grad_output): + feature_size = ctx.feature_size + spatial_scale = ctx.spatial_scale + aligned = ctx.aligned + clockwise = ctx.clockwise + sample_num = ctx.sample_num + rois = ctx.saved_tensors[0] + assert feature_size is not None + batch_size, num_channels, data_height, data_width = feature_size + + out_w = grad_output.size(3) + out_h = grad_output.size(2) + + grad_input = grad_rois = None + + if ctx.needs_input_grad[0]: + grad_input = rois.new_zeros(batch_size, num_channels, data_height, + data_width) + ext_module.roi_align_rotated_backward( + grad_output.contiguous(), + rois, + grad_input, + pooled_height=out_h, + pooled_width=out_w, + spatial_scale=spatial_scale, + sample_num=sample_num, + aligned=aligned, + clockwise=clockwise) + return grad_input, grad_rois, None, None, None, None, None + + +roi_align_rotated = RoIAlignRotatedFunction.apply + + +class RoIAlignRotated(nn.Module): + """RoI align pooling layer for rotated proposals. + + It accepts a feature map of shape (N, C, H, W) and rois with shape + (n, 6) with each roi decoded as (batch_index, center_x, center_y, + w, h, angle). The angle is in radian. + + Args: + out_size (tuple): h, w + spatial_scale (float): scale the input boxes by this number + sample_num (int): number of inputs samples to take for each + output sample. 0 to take samples densely for current models. + aligned (bool): if False, use the legacy implementation in + MMDetection. If True, align the results more perfectly. + Default: True. + clockwise (bool): If True, the angle in each proposal follows a + clockwise fashion in image space, otherwise, the angle is + counterclockwise. Default: False. + + Note: + The implementation of RoIAlign when aligned=True is modified from + https://github.com/facebookresearch/detectron2/ + + The meaning of aligned=True: + + Given a continuous coordinate c, its two neighboring pixel + indices (in our pixel model) are computed by floor(c - 0.5) and + ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete + indices [0] and [1] (which are sampled from the underlying signal + at continuous coordinates 0.5 and 1.5). But the original roi_align + (aligned=False) does not subtract the 0.5 when computing + neighboring pixel indices and therefore it uses pixels with a + slightly incorrect alignment (relative to our pixel model) when + performing bilinear interpolation. + + With `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5 + prior to calling roi_align. This produces the correct neighbors; + + The difference does not make a difference to the model's + performance if ROIAlign is used together with conv layers. + """ + + def __init__(self, + out_size, + spatial_scale, + sample_num=0, + aligned=True, + clockwise=False): + super(RoIAlignRotated, self).__init__() + + self.out_size = out_size + self.spatial_scale = float(spatial_scale) + self.sample_num = int(sample_num) + self.aligned = aligned + self.clockwise = clockwise + + def forward(self, features, rois): + return RoIAlignRotatedFunction.apply(features, rois, self.out_size, + self.spatial_scale, + self.sample_num, self.aligned, + self.clockwise) diff --git a/annotator/uniformer/mmcv/ops/roi_pool.py b/annotator/uniformer/mmcv/ops/roi_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..d339d8f2941eabc1cbe181a9c6c5ab5ff4ff4e5f --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roi_pool.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['roi_pool_forward', 'roi_pool_backward']) + + +class RoIPoolFunction(Function): + + @staticmethod + def symbolic(g, input, rois, output_size, spatial_scale): + return g.op( + 'MaxRoiPool', + input, + rois, + pooled_shape_i=output_size, + spatial_scale_f=spatial_scale) + + @staticmethod + def forward(ctx, input, rois, output_size, spatial_scale=1.0): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.input_shape = input.size() + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' + + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + argmax = input.new_zeros(output_shape, dtype=torch.int) + + ext_module.roi_pool_forward( + input, + rois, + output, + argmax, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale) + + ctx.save_for_backward(rois, argmax) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, argmax = ctx.saved_tensors + grad_input = grad_output.new_zeros(ctx.input_shape) + + ext_module.roi_pool_backward( + grad_output, + rois, + argmax, + grad_input, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale) + + return grad_input, None, None, None + + +roi_pool = RoIPoolFunction.apply + + +class RoIPool(nn.Module): + + def __init__(self, output_size, spatial_scale=1.0): + super(RoIPool, self).__init__() + + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + + def forward(self, input, rois): + return roi_pool(input, rois, self.output_size, self.spatial_scale) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(output_size={self.output_size}, ' + s += f'spatial_scale={self.spatial_scale})' + return s diff --git a/annotator/uniformer/mmcv/ops/roiaware_pool3d.py b/annotator/uniformer/mmcv/ops/roiaware_pool3d.py new file mode 100644 index 0000000000000000000000000000000000000000..291b0e5a9b692492c7d7e495ea639c46042e2f18 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roiaware_pool3d.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn as nn +from torch.autograd import Function + +import annotator.uniformer.mmcv as mmcv +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward']) + + +class RoIAwarePool3d(nn.Module): + """Encode the geometry-specific features of each 3D proposal. + + Please refer to `PartA2 `_ for more + details. + + Args: + out_size (int or tuple): The size of output features. n or + [n1, n2, n3]. + max_pts_per_voxel (int, optional): The maximum number of points per + voxel. Default: 128. + mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'. + Default: 'max'. + """ + + def __init__(self, out_size, max_pts_per_voxel=128, mode='max'): + super().__init__() + + self.out_size = out_size + self.max_pts_per_voxel = max_pts_per_voxel + assert mode in ['max', 'avg'] + pool_mapping = {'max': 0, 'avg': 1} + self.mode = pool_mapping[mode] + + def forward(self, rois, pts, pts_feature): + """ + Args: + rois (torch.Tensor): [N, 7], in LiDAR coordinate, + (x, y, z) is the bottom center of rois. + pts (torch.Tensor): [npoints, 3], coordinates of input points. + pts_feature (torch.Tensor): [npoints, C], features of input points. + + Returns: + pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C] + """ + + return RoIAwarePool3dFunction.apply(rois, pts, pts_feature, + self.out_size, + self.max_pts_per_voxel, self.mode) + + +class RoIAwarePool3dFunction(Function): + + @staticmethod + def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel, + mode): + """ + Args: + rois (torch.Tensor): [N, 7], in LiDAR coordinate, + (x, y, z) is the bottom center of rois. + pts (torch.Tensor): [npoints, 3], coordinates of input points. + pts_feature (torch.Tensor): [npoints, C], features of input points. + out_size (int or tuple): The size of output features. n or + [n1, n2, n3]. + max_pts_per_voxel (int): The maximum number of points per voxel. + Default: 128. + mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average + pool). + + Returns: + pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output + pooled features. + """ + + if isinstance(out_size, int): + out_x = out_y = out_z = out_size + else: + assert len(out_size) == 3 + assert mmcv.is_tuple_of(out_size, int) + out_x, out_y, out_z = out_size + + num_rois = rois.shape[0] + num_channels = pts_feature.shape[-1] + num_pts = pts.shape[0] + + pooled_features = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, num_channels)) + argmax = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int) + pts_idx_of_voxels = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, max_pts_per_voxel), + dtype=torch.int) + + ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax, + pts_idx_of_voxels, pooled_features, + mode) + + ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode, + num_pts, num_channels) + return pooled_features + + @staticmethod + def backward(ctx, grad_out): + ret = ctx.roiaware_pool3d_for_backward + pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret + + grad_in = grad_out.new_zeros((num_pts, num_channels)) + ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax, + grad_out.contiguous(), grad_in, + mode) + + return None, None, grad_in, None, None, None diff --git a/annotator/uniformer/mmcv/ops/roipoint_pool3d.py b/annotator/uniformer/mmcv/ops/roipoint_pool3d.py new file mode 100644 index 0000000000000000000000000000000000000000..0a21412c0728431c04b84245bc2e3109eea9aefc --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roipoint_pool3d.py @@ -0,0 +1,77 @@ +from torch import nn as nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward']) + + +class RoIPointPool3d(nn.Module): + """Encode the geometry-specific features of each 3D proposal. + + Please refer to `Paper of PartA2 `_ + for more details. + + Args: + num_sampled_points (int, optional): Number of samples in each roi. + Default: 512. + """ + + def __init__(self, num_sampled_points=512): + super().__init__() + self.num_sampled_points = num_sampled_points + + def forward(self, points, point_features, boxes3d): + """ + Args: + points (torch.Tensor): Input points whose shape is (B, N, C). + point_features (torch.Tensor): Features of input points whose shape + is (B, N, C). + boxes3d (B, M, 7), Input bounding boxes whose shape is (B, M, 7). + + Returns: + pooled_features (torch.Tensor): The output pooled features whose + shape is (B, M, 512, 3 + C). + pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M). + """ + return RoIPointPool3dFunction.apply(points, point_features, boxes3d, + self.num_sampled_points) + + +class RoIPointPool3dFunction(Function): + + @staticmethod + def forward(ctx, points, point_features, boxes3d, num_sampled_points=512): + """ + Args: + points (torch.Tensor): Input points whose shape is (B, N, C). + point_features (torch.Tensor): Features of input points whose shape + is (B, N, C). + boxes3d (B, M, 7), Input bounding boxes whose shape is (B, M, 7). + num_sampled_points (int, optional): The num of sampled points. + Default: 512. + + Returns: + pooled_features (torch.Tensor): The output pooled features whose + shape is (B, M, 512, 3 + C). + pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M). + """ + assert len(points.shape) == 3 and points.shape[2] == 3 + batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[ + 1], point_features.shape[2] + pooled_boxes3d = boxes3d.view(batch_size, -1, 7) + pooled_features = point_features.new_zeros( + (batch_size, boxes_num, num_sampled_points, 3 + feature_len)) + pooled_empty_flag = point_features.new_zeros( + (batch_size, boxes_num)).int() + + ext_module.roipoint_pool3d_forward(points.contiguous(), + pooled_boxes3d.contiguous(), + point_features.contiguous(), + pooled_features, pooled_empty_flag) + + return pooled_features, pooled_empty_flag + + @staticmethod + def backward(ctx, grad_out): + raise NotImplementedError diff --git a/annotator/uniformer/mmcv/ops/saconv.py b/annotator/uniformer/mmcv/ops/saconv.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee3978e097fca422805db4e31ae481006d7971 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/saconv.py @@ -0,0 +1,145 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init +from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version + + +@CONV_LAYERS.register_module(name='SAC') +class SAConv2d(ConvAWS2d): + """SAC (Switchable Atrous Convolution) + + This is an implementation of SAC in DetectoRS + (https://arxiv.org/pdf/2006.02334.pdf). + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + use_deform: If ``True``, replace convolution with deformable + convolution. Default: ``False``. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + use_deform=False): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.use_deform = use_deform + self.switch = nn.Conv2d( + self.in_channels, 1, kernel_size=1, stride=stride, bias=True) + self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size())) + self.pre_context = nn.Conv2d( + self.in_channels, self.in_channels, kernel_size=1, bias=True) + self.post_context = nn.Conv2d( + self.out_channels, self.out_channels, kernel_size=1, bias=True) + if self.use_deform: + self.offset_s = nn.Conv2d( + self.in_channels, + 18, + kernel_size=3, + padding=1, + stride=stride, + bias=True) + self.offset_l = nn.Conv2d( + self.in_channels, + 18, + kernel_size=3, + padding=1, + stride=stride, + bias=True) + self.init_weights() + + def init_weights(self): + constant_init(self.switch, 0, bias=1) + self.weight_diff.data.zero_() + constant_init(self.pre_context, 0) + constant_init(self.post_context, 0) + if self.use_deform: + constant_init(self.offset_s, 0) + constant_init(self.offset_l, 0) + + def forward(self, x): + # pre-context + avg_x = F.adaptive_avg_pool2d(x, output_size=1) + avg_x = self.pre_context(avg_x) + avg_x = avg_x.expand_as(x) + x = x + avg_x + # switch + avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect') + avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0) + switch = self.switch(avg_x) + # sac + weight = self._get_weight(self.weight) + zero_bias = torch.zeros( + self.out_channels, device=weight.device, dtype=weight.dtype) + + if self.use_deform: + offset = self.offset_s(avg_x) + out_s = deform_conv2d(x, offset, weight, self.stride, self.padding, + self.dilation, self.groups, 1) + else: + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.5.0')): + out_s = super().conv2d_forward(x, weight) + elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'): + # bias is a required argument of _conv_forward in torch 1.8.0 + out_s = super()._conv_forward(x, weight, zero_bias) + else: + out_s = super()._conv_forward(x, weight) + ori_p = self.padding + ori_d = self.dilation + self.padding = tuple(3 * p for p in self.padding) + self.dilation = tuple(3 * d for d in self.dilation) + weight = weight + self.weight_diff + if self.use_deform: + offset = self.offset_l(avg_x) + out_l = deform_conv2d(x, offset, weight, self.stride, self.padding, + self.dilation, self.groups, 1) + else: + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.5.0')): + out_l = super().conv2d_forward(x, weight) + elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'): + # bias is a required argument of _conv_forward in torch 1.8.0 + out_l = super()._conv_forward(x, weight, zero_bias) + else: + out_l = super()._conv_forward(x, weight) + + out = switch * out_s + (1 - switch) * out_l + self.padding = ori_p + self.dilation = ori_d + # post-context + avg_x = F.adaptive_avg_pool2d(out, output_size=1) + avg_x = self.post_context(avg_x) + avg_x = avg_x.expand_as(out) + out = out + avg_x + return out diff --git a/annotator/uniformer/mmcv/ops/scatter_points.py b/annotator/uniformer/mmcv/ops/scatter_points.py new file mode 100644 index 0000000000000000000000000000000000000000..2b8aa4169e9f6ca4a6f845ce17d6d1e4db416bb8 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/scatter_points.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', + ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward']) + + +class _DynamicScatter(Function): + + @staticmethod + def forward(ctx, feats, coors, reduce_type='max'): + """convert kitti points(N, >=3) to voxels. + + Args: + feats (torch.Tensor): [N, C]. Points features to be reduced + into voxels. + coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates + (specifically multi-dim voxel index) of each points. + reduce_type (str, optional): Reduce op. support 'max', 'sum' and + 'mean'. Default: 'max'. + + Returns: + voxel_feats (torch.Tensor): [M, C]. Reduced features, input + features that shares the same voxel coordinates are reduced to + one row. + voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates. + """ + results = ext_module.dynamic_point_to_voxel_forward( + feats, coors, reduce_type) + (voxel_feats, voxel_coors, point2voxel_map, + voxel_points_count) = results + ctx.reduce_type = reduce_type + ctx.save_for_backward(feats, voxel_feats, point2voxel_map, + voxel_points_count) + ctx.mark_non_differentiable(voxel_coors) + return voxel_feats, voxel_coors + + @staticmethod + def backward(ctx, grad_voxel_feats, grad_voxel_coors=None): + (feats, voxel_feats, point2voxel_map, + voxel_points_count) = ctx.saved_tensors + grad_feats = torch.zeros_like(feats) + # TODO: whether to use index put or use cuda_backward + # To use index put, need point to voxel index + ext_module.dynamic_point_to_voxel_backward( + grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats, + point2voxel_map, voxel_points_count, ctx.reduce_type) + return grad_feats, None, None + + +dynamic_scatter = _DynamicScatter.apply + + +class DynamicScatter(nn.Module): + """Scatters points into voxels, used in the voxel encoder with dynamic + voxelization. + + Note: + The CPU and GPU implementation get the same output, but have numerical + difference after summation and division (e.g., 5e-7). + + Args: + voxel_size (list): list [x, y, z] size of three dimension. + point_cloud_range (list): The coordinate range of points, [x_min, + y_min, z_min, x_max, y_max, z_max]. + average_points (bool): whether to use avg pooling to scatter points + into voxel. + """ + + def __init__(self, voxel_size, point_cloud_range, average_points: bool): + super().__init__() + + self.voxel_size = voxel_size + self.point_cloud_range = point_cloud_range + self.average_points = average_points + + def forward_single(self, points, coors): + """Scatters points into voxels. + + Args: + points (torch.Tensor): Points to be reduced into voxels. + coors (torch.Tensor): Corresponding voxel coordinates (specifically + multi-dim voxel index) of each points. + + Returns: + voxel_feats (torch.Tensor): Reduced features, input features that + shares the same voxel coordinates are reduced to one row. + voxel_coors (torch.Tensor): Voxel coordinates. + """ + reduce = 'mean' if self.average_points else 'max' + return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce) + + def forward(self, points, coors): + """Scatters points/features into voxels. + + Args: + points (torch.Tensor): Points to be reduced into voxels. + coors (torch.Tensor): Corresponding voxel coordinates (specifically + multi-dim voxel index) of each points. + + Returns: + voxel_feats (torch.Tensor): Reduced features, input features that + shares the same voxel coordinates are reduced to one row. + voxel_coors (torch.Tensor): Voxel coordinates. + """ + if coors.size(-1) == 3: + return self.forward_single(points, coors) + else: + batch_size = coors[-1, 0] + 1 + voxels, voxel_coors = [], [] + for i in range(batch_size): + inds = torch.where(coors[:, 0] == i) + voxel, voxel_coor = self.forward_single( + points[inds], coors[inds][:, 1:]) + coor_pad = nn.functional.pad( + voxel_coor, (1, 0), mode='constant', value=i) + voxel_coors.append(coor_pad) + voxels.append(voxel) + features = torch.cat(voxels, dim=0) + feature_coors = torch.cat(voxel_coors, dim=0) + + return features, feature_coors + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += 'voxel_size=' + str(self.voxel_size) + s += ', point_cloud_range=' + str(self.point_cloud_range) + s += ', average_points=' + str(self.average_points) + s += ')' + return s diff --git a/annotator/uniformer/mmcv/ops/sync_bn.py b/annotator/uniformer/mmcv/ops/sync_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b016fcbe860989c56cd1040034bcfa60e146d2 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/sync_bn.py @@ -0,0 +1,279 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.module import Module +from torch.nn.parameter import Parameter + +from annotator.uniformer.mmcv.cnn import NORM_LAYERS +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output', + 'sync_bn_backward_param', 'sync_bn_backward_data' +]) + + +class SyncBatchNormFunction(Function): + + @staticmethod + def symbolic(g, input, running_mean, running_var, weight, bias, momentum, + eps, group, group_size, stats_mode): + return g.op( + 'mmcv::MMCVSyncBatchNorm', + input, + running_mean, + running_var, + weight, + bias, + momentum_f=momentum, + eps_f=eps, + group_i=group, + group_size_i=group_size, + stats_mode=stats_mode) + + @staticmethod + def forward(self, input, running_mean, running_var, weight, bias, momentum, + eps, group, group_size, stats_mode): + self.momentum = momentum + self.eps = eps + self.group = group + self.group_size = group_size + self.stats_mode = stats_mode + + assert isinstance( + input, (torch.HalfTensor, torch.FloatTensor, + torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \ + f'only support Half or Float Tensor, but {input.type()}' + output = torch.zeros_like(input) + input3d = input.flatten(start_dim=2) + output3d = output.view_as(input3d) + num_channels = input3d.size(1) + + # ensure mean/var/norm/std are initialized as zeros + # ``torch.empty()`` does not guarantee that + mean = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + var = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + norm = torch.zeros_like( + input3d, dtype=torch.float, device=input3d.device) + std = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + + batch_size = input3d.size(0) + if batch_size > 0: + ext_module.sync_bn_forward_mean(input3d, mean) + batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype) + else: + # skip updating mean and leave it as zeros when the input is empty + batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype) + + # synchronize mean and the batch flag + vec = torch.cat([mean, batch_flag]) + if self.stats_mode == 'N': + vec *= batch_size + if self.group_size > 1: + dist.all_reduce(vec, group=self.group) + total_batch = vec[-1].detach() + mean = vec[:num_channels] + + if self.stats_mode == 'default': + mean = mean / self.group_size + elif self.stats_mode == 'N': + mean = mean / total_batch.clamp(min=1) + else: + raise NotImplementedError + + # leave var as zeros when the input is empty + if batch_size > 0: + ext_module.sync_bn_forward_var(input3d, mean, var) + + if self.stats_mode == 'N': + var *= batch_size + if self.group_size > 1: + dist.all_reduce(var, group=self.group) + + if self.stats_mode == 'default': + var /= self.group_size + elif self.stats_mode == 'N': + var /= total_batch.clamp(min=1) + else: + raise NotImplementedError + + # if the total batch size over all the ranks is zero, + # we should not update the statistics in the current batch + update_flag = total_batch.clamp(max=1) + momentum = update_flag * self.momentum + ext_module.sync_bn_forward_output( + input3d, + mean, + var, + weight, + bias, + running_mean, + running_var, + norm, + std, + output3d, + eps=self.eps, + momentum=momentum, + group_size=self.group_size) + self.save_for_backward(norm, std, weight) + return output + + @staticmethod + @once_differentiable + def backward(self, grad_output): + norm, std, weight = self.saved_tensors + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(weight) + grad_input = torch.zeros_like(grad_output) + grad_output3d = grad_output.flatten(start_dim=2) + grad_input3d = grad_input.view_as(grad_output3d) + + batch_size = grad_input3d.size(0) + if batch_size > 0: + ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight, + grad_bias) + + # all reduce + if self.group_size > 1: + dist.all_reduce(grad_weight, group=self.group) + dist.all_reduce(grad_bias, group=self.group) + grad_weight /= self.group_size + grad_bias /= self.group_size + + if batch_size > 0: + ext_module.sync_bn_backward_data(grad_output3d, weight, + grad_weight, grad_bias, norm, std, + grad_input3d) + + return grad_input, None, None, grad_weight, grad_bias, \ + None, None, None, None, None + + +@NORM_LAYERS.register_module(name='MMSyncBN') +class SyncBatchNorm(Module): + """Synchronized Batch Normalization. + + Args: + num_features (int): number of features/chennels in input tensor + eps (float, optional): a value added to the denominator for numerical + stability. Defaults to 1e-5. + momentum (float, optional): the value used for the running_mean and + running_var computation. Defaults to 0.1. + affine (bool, optional): whether to use learnable affine parameters. + Defaults to True. + track_running_stats (bool, optional): whether to track the running + mean and variance during training. When set to False, this + module does not track such statistics, and initializes statistics + buffers ``running_mean`` and ``running_var`` as ``None``. When + these buffers are ``None``, this module always uses batch + statistics in both training and eval modes. Defaults to True. + group (int, optional): synchronization of stats happen within + each process group individually. By default it is synchronization + across the whole world. Defaults to None. + stats_mode (str, optional): The statistical mode. Available options + includes ``'default'`` and ``'N'``. Defaults to 'default'. + When ``stats_mode=='default'``, it computes the overall statistics + using those from each worker with equal weight, i.e., the + statistics are synchronized and simply divied by ``group``. This + mode will produce inaccurate statistics when empty tensors occur. + When ``stats_mode=='N'``, it compute the overall statistics using + the total number of batches in each worker ignoring the number of + group, i.e., the statistics are synchronized and then divied by + the total batch ``N``. This mode is beneficial when empty tensors + occur during training, as it average the total mean by the real + number of batch. + """ + + def __init__(self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + group=None, + stats_mode='default'): + super(SyncBatchNorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + group = dist.group.WORLD if group is None else group + self.group = group + self.group_size = dist.get_world_size(group) + assert stats_mode in ['default', 'N'], \ + f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"' + self.stats_mode = stats_mode + if self.affine: + self.weight = Parameter(torch.Tensor(num_features)) + self.bias = Parameter(torch.Tensor(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + if self.track_running_stats: + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.register_buffer('num_batches_tracked', + torch.tensor(0, dtype=torch.long)) + else: + self.register_buffer('running_mean', None) + self.register_buffer('running_var', None) + self.register_buffer('num_batches_tracked', None) + self.reset_parameters() + + def reset_running_stats(self): + if self.track_running_stats: + self.running_mean.zero_() + self.running_var.fill_(1) + self.num_batches_tracked.zero_() + + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + self.weight.data.uniform_() # pytorch use ones_() + self.bias.data.zero_() + + def forward(self, input): + if input.dim() < 2: + raise ValueError( + f'expected at least 2D input, got {input.dim()}D input') + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float( + self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training or not self.track_running_stats: + return SyncBatchNormFunction.apply( + input, self.running_mean, self.running_var, self.weight, + self.bias, exponential_average_factor, self.eps, self.group, + self.group_size, self.stats_mode) + else: + return F.batch_norm(input, self.running_mean, self.running_var, + self.weight, self.bias, False, + exponential_average_factor, self.eps) + + def __repr__(self): + s = self.__class__.__name__ + s += f'({self.num_features}, ' + s += f'eps={self.eps}, ' + s += f'momentum={self.momentum}, ' + s += f'affine={self.affine}, ' + s += f'track_running_stats={self.track_running_stats}, ' + s += f'group_size={self.group_size},' + s += f'stats_mode={self.stats_mode})' + return s diff --git a/annotator/uniformer/mmcv/ops/three_interpolate.py b/annotator/uniformer/mmcv/ops/three_interpolate.py new file mode 100644 index 0000000000000000000000000000000000000000..203f47f05d58087e034fb3cd8cd6a09233947b4a --- /dev/null +++ b/annotator/uniformer/mmcv/ops/three_interpolate.py @@ -0,0 +1,68 @@ +from typing import Tuple + +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['three_interpolate_forward', 'three_interpolate_backward']) + + +class ThreeInterpolate(Function): + """Performs weighted linear interpolation on 3 features. + + Please refer to `Paper of PointNet++ `_ + for more details. + """ + + @staticmethod + def forward(ctx, features: torch.Tensor, indices: torch.Tensor, + weight: torch.Tensor) -> torch.Tensor: + """ + Args: + features (Tensor): (B, C, M) Features descriptors to be + interpolated + indices (Tensor): (B, n, 3) index three nearest neighbors + of the target features in features + weight (Tensor): (B, n, 3) weights of interpolation + + Returns: + Tensor: (B, C, N) tensor of the interpolated features + """ + assert features.is_contiguous() + assert indices.is_contiguous() + assert weight.is_contiguous() + + B, c, m = features.size() + n = indices.size(1) + ctx.three_interpolate_for_backward = (indices, weight, m) + output = torch.cuda.FloatTensor(B, c, n) + + ext_module.three_interpolate_forward( + features, indices, weight, output, b=B, c=c, m=m, n=n) + return output + + @staticmethod + def backward( + ctx, grad_out: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + grad_out (Tensor): (B, C, N) tensor with gradients of outputs + + Returns: + Tensor: (B, C, M) tensor with gradients of features + """ + idx, weight, m = ctx.three_interpolate_for_backward + B, c, n = grad_out.size() + + grad_features = torch.cuda.FloatTensor(B, c, m).zero_() + grad_out_data = grad_out.data.contiguous() + + ext_module.three_interpolate_backward( + grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m) + return grad_features, None, None + + +three_interpolate = ThreeInterpolate.apply diff --git a/annotator/uniformer/mmcv/ops/three_nn.py b/annotator/uniformer/mmcv/ops/three_nn.py new file mode 100644 index 0000000000000000000000000000000000000000..2b01047a129989cd5545a0a86f23a487f4a13ce1 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/three_nn.py @@ -0,0 +1,51 @@ +from typing import Tuple + +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['three_nn_forward']) + + +class ThreeNN(Function): + """Find the top-3 nearest neighbors of the target set from the source set. + + Please refer to `Paper of PointNet++ `_ + for more details. + """ + + @staticmethod + def forward(ctx, target: torch.Tensor, + source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + target (Tensor): shape (B, N, 3), points set that needs to + find the nearest neighbors. + source (Tensor): shape (B, M, 3), points set that is used + to find the nearest neighbors of points in target set. + + Returns: + Tensor: shape (B, N, 3), L2 distance of each point in target + set to their corresponding nearest neighbors. + """ + target = target.contiguous() + source = source.contiguous() + + B, N, _ = target.size() + m = source.size(1) + dist2 = torch.cuda.FloatTensor(B, N, 3) + idx = torch.cuda.IntTensor(B, N, 3) + + ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) + + return torch.sqrt(dist2), idx + + @staticmethod + def backward(ctx, a=None, b=None): + return None, None + + +three_nn = ThreeNN.apply diff --git a/annotator/uniformer/mmcv/ops/tin_shift.py b/annotator/uniformer/mmcv/ops/tin_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..472c9fcfe45a124e819b7ed5653e585f94a8811e --- /dev/null +++ b/annotator/uniformer/mmcv/ops/tin_shift.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Code reference from "Temporal Interlacing Network" +# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py +# Hao Shao, Shengju Qian, Yu Liu +# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk + +import torch +import torch.nn as nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['tin_shift_forward', 'tin_shift_backward']) + + +class TINShiftFunction(Function): + + @staticmethod + def forward(ctx, input, shift): + C = input.size(2) + num_segments = shift.size(1) + if C // num_segments <= 0 or C % num_segments != 0: + raise ValueError('C should be a multiple of num_segments, ' + f'but got C={C} and num_segments={num_segments}.') + + ctx.save_for_backward(shift) + + out = torch.zeros_like(input) + ext_module.tin_shift_forward(input, shift, out) + + return out + + @staticmethod + def backward(ctx, grad_output): + + shift = ctx.saved_tensors[0] + data_grad_input = grad_output.new(*grad_output.size()).zero_() + shift_grad_input = shift.new(*shift.size()).zero_() + ext_module.tin_shift_backward(grad_output, shift, data_grad_input) + + return data_grad_input, shift_grad_input + + +tin_shift = TINShiftFunction.apply + + +class TINShift(nn.Module): + """Temporal Interlace Shift. + + Temporal Interlace shift is a differentiable temporal-wise frame shifting + which is proposed in "Temporal Interlacing Network" + + Please refer to https://arxiv.org/abs/2001.06499 for more details. + Code is modified from https://github.com/mit-han-lab/temporal-shift-module + """ + + def forward(self, input, shift): + """Perform temporal interlace shift. + + Args: + input (Tensor): Feature map with shape [N, num_segments, C, H * W]. + shift (Tensor): Shift tensor with shape [N, num_segments]. + + Returns: + Feature map after temporal interlace shift. + """ + return tin_shift(input, shift) diff --git a/annotator/uniformer/mmcv/ops/upfirdn2d.py b/annotator/uniformer/mmcv/ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..c8bb2c3c949eed38a6465ed369fa881538dca010 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/upfirdn2d.py @@ -0,0 +1,330 @@ +# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator +# Augmentation (ADA) +# ======================================================================= + +# 1. Definitions + +# "Licensor" means any person or entity that distributes its Work. + +# "Software" means the original work of authorship made available under +# this License. + +# "Work" means the Software and any additions to or derivative works of +# the Software that are made available under this License. + +# The terms "reproduce," "reproduction," "derivative works," and +# "distribution" have the meaning as provided under U.S. copyright law; +# provided, however, that for the purposes of this License, derivative +# works shall not include works that remain separable from, or merely +# link (or bind by name) to the interfaces of, the Work. + +# Works, including the Software, are "made available" under this License +# by including in or with the Work either (a) a copyright notice +# referencing the applicability of this License to the Work, or (b) a +# copy of this License. + +# 2. License Grants + +# 2.1 Copyright Grant. Subject to the terms and conditions of this +# License, each Licensor grants to you a perpetual, worldwide, +# non-exclusive, royalty-free, copyright license to reproduce, +# prepare derivative works of, publicly display, publicly perform, +# sublicense and distribute its Work and any resulting derivative +# works in any form. + +# 3. Limitations + +# 3.1 Redistribution. You may reproduce or distribute the Work only +# if (a) you do so under this License, (b) you include a complete +# copy of this License with your distribution, and (c) you retain +# without modification any copyright, patent, trademark, or +# attribution notices that are present in the Work. + +# 3.2 Derivative Works. You may specify that additional or different +# terms apply to the use, reproduction, and distribution of your +# derivative works of the Work ("Your Terms") only if (a) Your Terms +# provide that the use limitation in Section 3.3 applies to your +# derivative works, and (b) you identify the specific derivative +# works that are subject to Your Terms. Notwithstanding Your Terms, +# this License (including the redistribution requirements in Section +# 3.1) will continue to apply to the Work itself. + +# 3.3 Use Limitation. The Work and any derivative works thereof only +# may be used or intended for use non-commercially. Notwithstanding +# the foregoing, NVIDIA and its affiliates may use the Work and any +# derivative works commercially. As used herein, "non-commercially" +# means for research or evaluation purposes only. + +# 3.4 Patent Claims. If you bring or threaten to bring a patent claim +# against any Licensor (including any claim, cross-claim or +# counterclaim in a lawsuit) to enforce any patents that you allege +# are infringed by any Work, then your rights under this License from +# such Licensor (including the grant in Section 2.1) will terminate +# immediately. + +# 3.5 Trademarks. This License does not grant any rights to use any +# Licensor’s or its affiliates’ names, logos, or trademarks, except +# as necessary to reproduce the notices described in this License. + +# 3.6 Termination. If you violate any term of this License, then your +# rights under this License (including the grant in Section 2.1) will +# terminate immediately. + +# 4. Disclaimer of Warranty. + +# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +# THIS LICENSE. + +# 5. Limitation of Liability. + +# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGES. + +# ======================================================================= + +import torch +from torch.autograd import Function +from torch.nn import functional as F + +from annotator.uniformer.mmcv.utils import to_2tuple +from ..utils import ext_loader + +upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d']) + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, + in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + up_x=down_x, + up_y=down_y, + down_x=up_x, + down_y=up_y, + pad_x0=g_pad_x0, + pad_x1=g_pad_x1, + pad_y0=g_pad_y0, + pad_y1=g_pad_y1) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], + in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], + ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + up_x=ctx.up_x, + up_y=ctx.up_y, + down_x=ctx.down_x, + down_y=ctx.down_y, + pad_x0=ctx.pad_x0, + pad_x1=ctx.pad_x1, + pad_y0=ctx.pad_y0, + pad_y1=ctx.pad_y1) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], + ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d( + input, + kernel, + up_x=up_x, + up_y=up_y, + down_x=down_x, + down_y=down_y, + pad_x0=pad_x0, + pad_x1=pad_x1, + pad_y0=pad_y0, + pad_y1=pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + """UpFRIDn for 2d features. + + UpFIRDn is short for upsample, apply FIR filter and downsample. More + details can be found in: + https://www.mathworks.com/help/signal/ref/upfirdn.html + + Args: + input (Tensor): Tensor with shape of (n, c, h, w). + kernel (Tensor): Filter kernel. + up (int | tuple[int], optional): Upsampling factor. If given a number, + we will use this factor for the both height and width side. + Defaults to 1. + down (int | tuple[int], optional): Downsampling factor. If given a + number, we will use this factor for the both height and width side. + Defaults to 1. + pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or + (x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0). + + Returns: + Tensor: Tensor after UpFIRDn. + """ + if input.device.type == 'cpu': + if len(pad) == 2: + pad = (pad[0], pad[1], pad[0], pad[1]) + + up = to_2tuple(up) + + down = to_2tuple(down) + + out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1], + pad[0], pad[1], pad[2], pad[3]) + else: + _up = to_2tuple(up) + + _down = to_2tuple(down) + + if len(pad) == 4: + _pad = pad + elif len(pad) == 2: + _pad = (pad[0], pad[1], pad[0], pad[1]) + + out = UpFirDn2d.apply(input, kernel, _up, _down, _pad) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, + pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, + [0, 0, + max(pad_x0, 0), + max(pad_x1, 0), + max(pad_y0, 0), + max(pad_y1, 0)]) + out = out[:, + max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/annotator/uniformer/mmcv/ops/voxelize.py b/annotator/uniformer/mmcv/ops/voxelize.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3226a4fbcbfe58490fa2ea8e1c16b531214121 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/voxelize.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.autograd import Function +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward']) + + +class _Voxelization(Function): + + @staticmethod + def forward(ctx, + points, + voxel_size, + coors_range, + max_points=35, + max_voxels=20000): + """Convert kitti points(N, >=3) to voxels. + + Args: + points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points + and points[:, 3:] contain other information like reflectivity. + voxel_size (tuple or float): The size of voxel with the shape of + [3]. + coors_range (tuple or float): The coordinate range of voxel with + the shape of [6]. + max_points (int, optional): maximum points contained in a voxel. if + max_points=-1, it means using dynamic_voxelize. Default: 35. + max_voxels (int, optional): maximum voxels this function create. + for second, 20000 is a good choice. Users should shuffle points + before call this function because max_voxels may drop points. + Default: 20000. + + Returns: + voxels_out (torch.Tensor): Output voxels with the shape of [M, + max_points, ndim]. Only contain points and returned when + max_points != -1. + coors_out (torch.Tensor): Output coordinates with the shape of + [M, 3]. + num_points_per_voxel_out (torch.Tensor): Num points per voxel with + the shape of [M]. Only returned when max_points != -1. + """ + if max_points == -1 or max_voxels == -1: + coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int) + ext_module.dynamic_voxelize_forward(points, coors, voxel_size, + coors_range, 3) + return coors + else: + voxels = points.new_zeros( + size=(max_voxels, max_points, points.size(1))) + coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int) + num_points_per_voxel = points.new_zeros( + size=(max_voxels, ), dtype=torch.int) + voxel_num = ext_module.hard_voxelize_forward( + points, voxels, coors, num_points_per_voxel, voxel_size, + coors_range, max_points, max_voxels, 3) + # select the valid voxels + voxels_out = voxels[:voxel_num] + coors_out = coors[:voxel_num] + num_points_per_voxel_out = num_points_per_voxel[:voxel_num] + return voxels_out, coors_out, num_points_per_voxel_out + + +voxelization = _Voxelization.apply + + +class Voxelization(nn.Module): + """Convert kitti points(N, >=3) to voxels. + + Please refer to `PVCNN `_ for more + details. + + Args: + voxel_size (tuple or float): The size of voxel with the shape of [3]. + point_cloud_range (tuple or float): The coordinate range of voxel with + the shape of [6]. + max_num_points (int): maximum points contained in a voxel. if + max_points=-1, it means using dynamic_voxelize. + max_voxels (int, optional): maximum voxels this function create. + for second, 20000 is a good choice. Users should shuffle points + before call this function because max_voxels may drop points. + Default: 20000. + """ + + def __init__(self, + voxel_size, + point_cloud_range, + max_num_points, + max_voxels=20000): + super().__init__() + + self.voxel_size = voxel_size + self.point_cloud_range = point_cloud_range + self.max_num_points = max_num_points + if isinstance(max_voxels, tuple): + self.max_voxels = max_voxels + else: + self.max_voxels = _pair(max_voxels) + + point_cloud_range = torch.tensor( + point_cloud_range, dtype=torch.float32) + voxel_size = torch.tensor(voxel_size, dtype=torch.float32) + grid_size = (point_cloud_range[3:] - + point_cloud_range[:3]) / voxel_size + grid_size = torch.round(grid_size).long() + input_feat_shape = grid_size[:2] + self.grid_size = grid_size + # the origin shape is as [x-len, y-len, z-len] + # [w, h, d] -> [d, h, w] + self.pcd_shape = [*input_feat_shape, 1][::-1] + + def forward(self, input): + if self.training: + max_voxels = self.max_voxels[0] + else: + max_voxels = self.max_voxels[1] + + return voxelization(input, self.voxel_size, self.point_cloud_range, + self.max_num_points, max_voxels) + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += 'voxel_size=' + str(self.voxel_size) + s += ', point_cloud_range=' + str(self.point_cloud_range) + s += ', max_num_points=' + str(self.max_num_points) + s += ', max_voxels=' + str(self.max_voxels) + s += ')' + return s diff --git a/annotator/uniformer/mmcv/parallel/__init__.py b/annotator/uniformer/mmcv/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed2c17ad357742e423beeaf4d35db03fe9af469 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .collate import collate +from .data_container import DataContainer +from .data_parallel import MMDataParallel +from .distributed import MMDistributedDataParallel +from .registry import MODULE_WRAPPERS +from .scatter_gather import scatter, scatter_kwargs +from .utils import is_module_wrapper + +__all__ = [ + 'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel', + 'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS' +] diff --git a/annotator/uniformer/mmcv/parallel/_functions.py b/annotator/uniformer/mmcv/parallel/_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5a8a44483ab991411d07122b22a1d027e4be8e --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/_functions.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel._functions import _get_stream + + +def scatter(input, devices, streams=None): + """Scatters tensor across multiple GPUs.""" + if streams is None: + streams = [None] * len(devices) + + if isinstance(input, list): + chunk_size = (len(input) - 1) // len(devices) + 1 + outputs = [ + scatter(input[i], [devices[i // chunk_size]], + [streams[i // chunk_size]]) for i in range(len(input)) + ] + return outputs + elif isinstance(input, torch.Tensor): + output = input.contiguous() + # TODO: copy to a pinned buffer first (if copying from CPU) + stream = streams[0] if output.numel() > 0 else None + if devices != [-1]: + with torch.cuda.device(devices[0]), torch.cuda.stream(stream): + output = output.cuda(devices[0], non_blocking=True) + else: + # unsqueeze the first dimension thus the tensor's shape is the + # same as those scattered with GPU. + output = output.unsqueeze(0) + return output + else: + raise Exception(f'Unknown type {type(input)}.') + + +def synchronize_stream(output, devices, streams): + if isinstance(output, list): + chunk_size = len(output) // len(devices) + for i in range(len(devices)): + for j in range(chunk_size): + synchronize_stream(output[i * chunk_size + j], [devices[i]], + [streams[i]]) + elif isinstance(output, torch.Tensor): + if output.numel() != 0: + with torch.cuda.device(devices[0]): + main_stream = torch.cuda.current_stream() + main_stream.wait_stream(streams[0]) + output.record_stream(main_stream) + else: + raise Exception(f'Unknown type {type(output)}.') + + +def get_input_device(input): + if isinstance(input, list): + for item in input: + input_device = get_input_device(item) + if input_device != -1: + return input_device + return -1 + elif isinstance(input, torch.Tensor): + return input.get_device() if input.is_cuda else -1 + else: + raise Exception(f'Unknown type {type(input)}.') + + +class Scatter: + + @staticmethod + def forward(target_gpus, input): + input_device = get_input_device(input) + streams = None + if input_device == -1 and target_gpus != [-1]: + # Perform CPU to GPU copies in a background stream + streams = [_get_stream(device) for device in target_gpus] + + outputs = scatter(input, target_gpus, streams) + # Synchronize with the copy stream + if streams is not None: + synchronize_stream(outputs, target_gpus, streams) + + return tuple(outputs) diff --git a/annotator/uniformer/mmcv/parallel/collate.py b/annotator/uniformer/mmcv/parallel/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..ad749197df21b0d74297548be5f66a696adebf7f --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/collate.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections.abc import Mapping, Sequence + +import torch +import torch.nn.functional as F +from torch.utils.data.dataloader import default_collate + +from .data_container import DataContainer + + +def collate(batch, samples_per_gpu=1): + """Puts each data field into a tensor/DataContainer with outer dimension + batch size. + + Extend default_collate to add support for + :type:`~mmcv.parallel.DataContainer`. There are 3 cases. + + 1. cpu_only = True, e.g., meta data + 2. cpu_only = False, stack = True, e.g., images tensors + 3. cpu_only = False, stack = False, e.g., gt bboxes + """ + + if not isinstance(batch, Sequence): + raise TypeError(f'{batch.dtype} is not supported.') + + if isinstance(batch[0], DataContainer): + stacked = [] + if batch[0].cpu_only: + for i in range(0, len(batch), samples_per_gpu): + stacked.append( + [sample.data for sample in batch[i:i + samples_per_gpu]]) + return DataContainer( + stacked, batch[0].stack, batch[0].padding_value, cpu_only=True) + elif batch[0].stack: + for i in range(0, len(batch), samples_per_gpu): + assert isinstance(batch[i].data, torch.Tensor) + + if batch[i].pad_dims is not None: + ndim = batch[i].dim() + assert ndim > batch[i].pad_dims + max_shape = [0 for _ in range(batch[i].pad_dims)] + for dim in range(1, batch[i].pad_dims + 1): + max_shape[dim - 1] = batch[i].size(-dim) + for sample in batch[i:i + samples_per_gpu]: + for dim in range(0, ndim - batch[i].pad_dims): + assert batch[i].size(dim) == sample.size(dim) + for dim in range(1, batch[i].pad_dims + 1): + max_shape[dim - 1] = max(max_shape[dim - 1], + sample.size(-dim)) + padded_samples = [] + for sample in batch[i:i + samples_per_gpu]: + pad = [0 for _ in range(batch[i].pad_dims * 2)] + for dim in range(1, batch[i].pad_dims + 1): + pad[2 * dim - + 1] = max_shape[dim - 1] - sample.size(-dim) + padded_samples.append( + F.pad( + sample.data, pad, value=sample.padding_value)) + stacked.append(default_collate(padded_samples)) + elif batch[i].pad_dims is None: + stacked.append( + default_collate([ + sample.data + for sample in batch[i:i + samples_per_gpu] + ])) + else: + raise ValueError( + 'pad_dims should be either None or integers (1-3)') + + else: + for i in range(0, len(batch), samples_per_gpu): + stacked.append( + [sample.data for sample in batch[i:i + samples_per_gpu]]) + return DataContainer(stacked, batch[0].stack, batch[0].padding_value) + elif isinstance(batch[0], Sequence): + transposed = zip(*batch) + return [collate(samples, samples_per_gpu) for samples in transposed] + elif isinstance(batch[0], Mapping): + return { + key: collate([d[key] for d in batch], samples_per_gpu) + for key in batch[0] + } + else: + return default_collate(batch) diff --git a/annotator/uniformer/mmcv/parallel/data_container.py b/annotator/uniformer/mmcv/parallel/data_container.py new file mode 100644 index 0000000000000000000000000000000000000000..cedb0d32a51a1f575a622b38de2cee3ab4757821 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/data_container.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch + + +def assert_tensor_type(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not isinstance(args[0].data, torch.Tensor): + raise AttributeError( + f'{args[0].__class__.__name__} has no attribute ' + f'{func.__name__} for type {args[0].datatype}') + return func(*args, **kwargs) + + return wrapper + + +class DataContainer: + """A container for any type of objects. + + Typically tensors will be stacked in the collate function and sliced along + some dimension in the scatter function. This behavior has some limitations. + 1. All tensors have to be the same size. + 2. Types are limited (numpy array or Tensor). + + We design `DataContainer` and `MMDataParallel` to overcome these + limitations. The behavior can be either of the following. + + - copy to GPU, pad all tensors to the same size and stack them + - copy to GPU without stacking + - leave the objects as is and pass it to the model + - pad_dims specifies the number of last few dimensions to do padding + """ + + def __init__(self, + data, + stack=False, + padding_value=0, + cpu_only=False, + pad_dims=2): + self._data = data + self._cpu_only = cpu_only + self._stack = stack + self._padding_value = padding_value + assert pad_dims in [None, 1, 2, 3] + self._pad_dims = pad_dims + + def __repr__(self): + return f'{self.__class__.__name__}({repr(self.data)})' + + def __len__(self): + return len(self._data) + + @property + def data(self): + return self._data + + @property + def datatype(self): + if isinstance(self.data, torch.Tensor): + return self.data.type() + else: + return type(self.data) + + @property + def cpu_only(self): + return self._cpu_only + + @property + def stack(self): + return self._stack + + @property + def padding_value(self): + return self._padding_value + + @property + def pad_dims(self): + return self._pad_dims + + @assert_tensor_type + def size(self, *args, **kwargs): + return self.data.size(*args, **kwargs) + + @assert_tensor_type + def dim(self): + return self.data.dim() diff --git a/annotator/uniformer/mmcv/parallel/data_parallel.py b/annotator/uniformer/mmcv/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..79b5f69b654cf647dc7ae9174223781ab5c607d2 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/data_parallel.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import chain + +from torch.nn.parallel import DataParallel + +from .scatter_gather import scatter_kwargs + + +class MMDataParallel(DataParallel): + """The DataParallel module that supports DataContainer. + + MMDataParallel has two main differences with PyTorch DataParallel: + + - It supports a custom type :class:`DataContainer` which allows more + flexible control of input data during both GPU and CPU inference. + - It implement two more APIs ``train_step()`` and ``val_step()``. + + Args: + module (:class:`nn.Module`): Module to be encapsulated. + device_ids (list[int]): Device IDS of modules to be scattered to. + Defaults to None when GPU is not available. + output_device (str | int): Device ID for output. Defaults to None. + dim (int): Dimension used to scatter the data. Defaults to 0. + """ + + def __init__(self, *args, dim=0, **kwargs): + super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs) + self.dim = dim + + def forward(self, *inputs, **kwargs): + """Override the original forward function. + + The main difference lies in the CPU inference where the data in + :class:`DataContainers` will still be gathered. + """ + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module(*inputs[0], **kwargs[0]) + else: + return super().forward(*inputs, **kwargs) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def train_step(self, *inputs, **kwargs): + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module.train_step(*inputs[0], **kwargs[0]) + + assert len(self.device_ids) == 1, \ + ('MMDataParallel only supports single GPU training, if you need to' + ' train with multiple GPUs, please use MMDistributedDataParallel' + 'instead.') + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError( + 'module must have its parameters and buffers ' + f'on device {self.src_device_obj} (device_ids[0]) but ' + f'found one of them on device: {t.device}') + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + return self.module.train_step(*inputs[0], **kwargs[0]) + + def val_step(self, *inputs, **kwargs): + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module.val_step(*inputs[0], **kwargs[0]) + + assert len(self.device_ids) == 1, \ + ('MMDataParallel only supports single GPU training, if you need to' + ' train with multiple GPUs, please use MMDistributedDataParallel' + ' instead.') + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError( + 'module must have its parameters and buffers ' + f'on device {self.src_device_obj} (device_ids[0]) but ' + f'found one of them on device: {t.device}') + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + return self.module.val_step(*inputs[0], **kwargs[0]) diff --git a/annotator/uniformer/mmcv/parallel/distributed.py b/annotator/uniformer/mmcv/parallel/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4c27903db58a54d37ea1ed9ec0104098b486f2 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/distributed.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel.distributed import (DistributedDataParallel, + _find_tensors) + +from annotator.uniformer.mmcv import print_log +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version +from .scatter_gather import scatter_kwargs + + +class MMDistributedDataParallel(DistributedDataParallel): + """The DDP module that supports DataContainer. + + MMDDP has two main differences with PyTorch DDP: + + - It supports a custom type :class:`DataContainer` which allows more + flexible control of input data. + - It implement two APIs ``train_step()`` and ``val_step()``. + """ + + def to_kwargs(self, inputs, kwargs, device_id): + # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8 + # to move all tensors to device_id + return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def train_step(self, *inputs, **kwargs): + """train_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.train_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='mmcv') + + if getattr(self, 'require_forward_param_sync', True): + self._sync_params() + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.train_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.train_step(*inputs, **kwargs) + + if torch.is_grad_enabled() and getattr( + self, 'require_backward_grad_sync', True): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output + + def val_step(self, *inputs, **kwargs): + """val_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.val_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='mmcv') + + if getattr(self, 'require_forward_param_sync', True): + self._sync_params() + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.val_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.val_step(*inputs, **kwargs) + + if torch.is_grad_enabled() and getattr( + self, 'require_backward_grad_sync', True): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output diff --git a/annotator/uniformer/mmcv/parallel/distributed_deprecated.py b/annotator/uniformer/mmcv/parallel/distributed_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..676937a2085d4da20fa87923041a200fca6214eb --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/distributed_deprecated.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.distributed as dist +import torch.nn as nn +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version +from .registry import MODULE_WRAPPERS +from .scatter_gather import scatter_kwargs + + +@MODULE_WRAPPERS.register_module() +class MMDistributedDataParallel(nn.Module): + + def __init__(self, + module, + dim=0, + broadcast_buffers=True, + bucket_cap_mb=25): + super(MMDistributedDataParallel, self).__init__() + self.module = module + self.dim = dim + self.broadcast_buffers = broadcast_buffers + + self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024 + self._sync_params() + + def _dist_broadcast_coalesced(self, tensors, buffer_size): + for tensors in _take_tensors(tensors, buffer_size): + flat_tensors = _flatten_dense_tensors(tensors) + dist.broadcast(flat_tensors, 0) + for tensor, synced in zip( + tensors, _unflatten_dense_tensors(flat_tensors, tensors)): + tensor.copy_(synced) + + def _sync_params(self): + module_states = list(self.module.state_dict().values()) + if len(module_states) > 0: + self._dist_broadcast_coalesced(module_states, + self.broadcast_bucket_size) + if self.broadcast_buffers: + if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) < digit_version('1.0')): + buffers = [b.data for b in self.module._all_buffers()] + else: + buffers = [b.data for b in self.module.buffers()] + if len(buffers) > 0: + self._dist_broadcast_coalesced(buffers, + self.broadcast_bucket_size) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def forward(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + return self.module(*inputs[0], **kwargs[0]) + + def train_step(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.train_step(*inputs[0], **kwargs[0]) + return output + + def val_step(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.val_step(*inputs[0], **kwargs[0]) + return output diff --git a/annotator/uniformer/mmcv/parallel/registry.py b/annotator/uniformer/mmcv/parallel/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..a204a07fba10e614223f090d1a57cf9c4d74d4a1 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/registry.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from annotator.uniformer.mmcv.utils import Registry + +MODULE_WRAPPERS = Registry('module wrapper') +MODULE_WRAPPERS.register_module(module=DataParallel) +MODULE_WRAPPERS.register_module(module=DistributedDataParallel) diff --git a/annotator/uniformer/mmcv/parallel/scatter_gather.py b/annotator/uniformer/mmcv/parallel/scatter_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..900ff88566f8f14830590459dc4fd16d4b382e47 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/scatter_gather.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel._functions import Scatter as OrigScatter + +from ._functions import Scatter +from .data_container import DataContainer + + +def scatter(inputs, target_gpus, dim=0): + """Scatter inputs to target gpus. + + The only difference from original :func:`scatter` is to add support for + :type:`~mmcv.parallel.DataContainer`. + """ + + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + if target_gpus != [-1]: + return OrigScatter.apply(target_gpus, None, dim, obj) + else: + # for CPU inference we use self-implemented scatter + return Scatter.forward(target_gpus, obj) + if isinstance(obj, DataContainer): + if obj.cpu_only: + return obj.data + else: + return Scatter.forward(target_gpus, obj.data) + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + out = list(map(list, zip(*map(scatter_map, obj)))) + return out + if isinstance(obj, dict) and len(obj) > 0: + out = list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return out + return [obj for targets in target_gpus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + return scatter_map(inputs) + finally: + scatter_map = None + + +def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): + """Scatter with support for kwargs dictionary.""" + inputs = scatter(inputs, target_gpus, dim) if inputs else [] + kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs diff --git a/annotator/uniformer/mmcv/parallel/utils.py b/annotator/uniformer/mmcv/parallel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f5712cb42c38a2e8563bf563efb6681383cab9b --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/utils.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .registry import MODULE_WRAPPERS + + +def is_module_wrapper(module): + """Check if a module is a module wrapper. + + The following 3 modules in MMCV (and their subclasses) are regarded as + module wrappers: DataParallel, DistributedDataParallel, + MMDistributedDataParallel (the deprecated version). You may add you own + module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: True if the input module is a module wrapper. + """ + module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values()) + return isinstance(module, module_wrappers) diff --git a/annotator/uniformer/mmcv/runner/__init__.py b/annotator/uniformer/mmcv/runner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52e4b48d383a84a055dcd7f6236f6e8e58eab924 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/__init__.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_module import BaseModule, ModuleList, Sequential +from .base_runner import BaseRunner +from .builder import RUNNERS, build_runner +from .checkpoint import (CheckpointLoader, _load_checkpoint, + _load_checkpoint_with_prefix, load_checkpoint, + load_state_dict, save_checkpoint, weights_to_cpu) +from .default_constructor import DefaultRunnerConstructor +from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info, + init_dist, master_only) +from .epoch_based_runner import EpochBasedRunner, Runner +from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model +from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook, + DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook, + Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, Hook, IterTimerHook, + LoggerHook, LrUpdaterHook, MlflowLoggerHook, + NeptuneLoggerHook, OptimizerHook, PaviLoggerHook, + SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook, + WandbLoggerHook) +from .iter_based_runner import IterBasedRunner, IterLoader +from .log_buffer import LogBuffer +from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS, + DefaultOptimizerConstructor, build_optimizer, + build_optimizer_constructor) +from .priority import Priority, get_priority +from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed + +__all__ = [ + 'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer', + 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', + 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', + 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', + 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook', + 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict', + 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority', + 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict', + 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS', + 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer', + 'build_optimizer_constructor', 'IterLoader', 'set_random_seed', + 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook', + 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', + 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', + '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential', + 'ModuleList', 'GradientCumulativeOptimizerHook', + 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor' +] diff --git a/annotator/uniformer/mmcv/runner/base_module.py b/annotator/uniformer/mmcv/runner/base_module.py new file mode 100644 index 0000000000000000000000000000000000000000..617fad9bb89f10a9a0911d962dfb3bc8f3a3628c --- /dev/null +++ b/annotator/uniformer/mmcv/runner/base_module.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from abc import ABCMeta +from collections import defaultdict +from logging import FileHandler + +import torch.nn as nn + +from annotator.uniformer.mmcv.runner.dist_utils import master_only +from annotator.uniformer.mmcv.utils.logging import get_logger, logger_initialized, print_log + + +class BaseModule(nn.Module, metaclass=ABCMeta): + """Base module for all modules in openmmlab. + + ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional + functionality of parameter initialization. Compared with + ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes. + + - ``init_cfg``: the config to control the initialization. + - ``init_weights``: The function of parameter + initialization and recording initialization + information. + - ``_params_init_info``: Used to track the parameter + initialization information. This attribute only + exists during executing the ``init_weights``. + + Args: + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, init_cfg=None): + """Initialize BaseModule, inherited from `torch.nn.Module`""" + + # NOTE init_cfg can be defined in different levels, but init_cfg + # in low levels has a higher priority. + + super(BaseModule, self).__init__() + # define default value of init_cfg instead of hard code + # in init_weights() function + self._is_init = False + + self.init_cfg = copy.deepcopy(init_cfg) + + # Backward compatibility in derived classes + # if pretrained is not None: + # warnings.warn('DeprecationWarning: pretrained is a deprecated \ + # key, please consider using init_cfg') + # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + @property + def is_init(self): + return self._is_init + + def init_weights(self): + """Initialize the weights.""" + + is_top_level_module = False + # check if it is top-level module + if not hasattr(self, '_params_init_info'): + # The `_params_init_info` is used to record the initialization + # information of the parameters + # the key should be the obj:`nn.Parameter` of model and the value + # should be a dict containing + # - init_info (str): The string that describes the initialization. + # - tmp_mean_value (FloatTensor): The mean of the parameter, + # which indicates whether the parameter has been modified. + # this attribute would be deleted after all parameters + # is initialized. + self._params_init_info = defaultdict(dict) + is_top_level_module = True + + # Initialize the `_params_init_info`, + # When detecting the `tmp_mean_value` of + # the corresponding parameter is changed, update related + # initialization information + for name, param in self.named_parameters(): + self._params_init_info[param][ + 'init_info'] = f'The value is the same before and ' \ + f'after calling `init_weights` ' \ + f'of {self.__class__.__name__} ' + self._params_init_info[param][ + 'tmp_mean_value'] = param.data.mean() + + # pass `params_init_info` to all submodules + # All submodules share the same `params_init_info`, + # so it will be updated when parameters are + # modified at any level of the model. + for sub_module in self.modules(): + sub_module._params_init_info = self._params_init_info + + # Get the initialized logger, if not exist, + # create a logger named `mmcv` + logger_names = list(logger_initialized.keys()) + logger_name = logger_names[0] if logger_names else 'mmcv' + + from ..cnn import initialize + from ..cnn.utils.weight_init import update_init_info + module_name = self.__class__.__name__ + if not self._is_init: + if self.init_cfg: + print_log( + f'initialize {module_name} with init_cfg {self.init_cfg}', + logger=logger_name) + initialize(self, self.init_cfg) + if isinstance(self.init_cfg, dict): + # prevent the parameters of + # the pre-trained model + # from being overwritten by + # the `init_weights` + if self.init_cfg['type'] == 'Pretrained': + return + + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights() + # users may overload the `init_weights` + update_init_info( + m, + init_info=f'Initialized by ' + f'user-defined `init_weights`' + f' in {m.__class__.__name__} ') + + self._is_init = True + else: + warnings.warn(f'init_weights of {self.__class__.__name__} has ' + f'been called more than once.') + + if is_top_level_module: + self._dump_init_info(logger_name) + + for sub_module in self.modules(): + del sub_module._params_init_info + + @master_only + def _dump_init_info(self, logger_name): + """Dump the initialization information to a file named + `initialization.log.json` in workdir. + + Args: + logger_name (str): The name of logger. + """ + + logger = get_logger(logger_name) + + with_file_handler = False + # dump the information to the logger file if there is a `FileHandler` + for handler in logger.handlers: + if isinstance(handler, FileHandler): + handler.stream.write( + 'Name of parameter - Initialization information\n') + for name, param in self.named_parameters(): + handler.stream.write( + f'\n{name} - {param.shape}: ' + f"\n{self._params_init_info[param]['init_info']} \n") + handler.stream.flush() + with_file_handler = True + if not with_file_handler: + for name, param in self.named_parameters(): + print_log( + f'\n{name} - {param.shape}: ' + f"\n{self._params_init_info[param]['init_info']} \n ", + logger=logger_name) + + def __repr__(self): + s = super().__repr__() + if self.init_cfg: + s += f'\ninit_cfg={self.init_cfg}' + return s + + +class Sequential(BaseModule, nn.Sequential): + """Sequential module in openmmlab. + + Args: + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, *args, init_cfg=None): + BaseModule.__init__(self, init_cfg) + nn.Sequential.__init__(self, *args) + + +class ModuleList(BaseModule, nn.ModuleList): + """ModuleList in openmmlab. + + Args: + modules (iterable, optional): an iterable of modules to add. + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, modules=None, init_cfg=None): + BaseModule.__init__(self, init_cfg) + nn.ModuleList.__init__(self, modules) diff --git a/annotator/uniformer/mmcv/runner/base_runner.py b/annotator/uniformer/mmcv/runner/base_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..4928db0a73b56fe0218a4bf66ec4ffa082d31ccc --- /dev/null +++ b/annotator/uniformer/mmcv/runner/base_runner.py @@ -0,0 +1,542 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging +import os.path as osp +import warnings +from abc import ABCMeta, abstractmethod + +import torch +from torch.optim import Optimizer + +import annotator.uniformer.mmcv as mmcv +from ..parallel import is_module_wrapper +from .checkpoint import load_checkpoint +from .dist_utils import get_dist_info +from .hooks import HOOKS, Hook +from .log_buffer import LogBuffer +from .priority import Priority, get_priority +from .utils import get_time_str + + +class BaseRunner(metaclass=ABCMeta): + """The base class of Runner, a training helper for PyTorch. + + All subclasses should implement the following APIs: + + - ``run()`` + - ``train()`` + - ``val()`` + - ``save_checkpoint()`` + + Args: + model (:obj:`torch.nn.Module`): The model to be run. + batch_processor (callable): A callable method that process a data + batch. The interface of this method should be + `batch_processor(model, data, train_mode) -> dict` + optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an + optimizer (in most cases) or a dict of optimizers (in models that + requires more than one optimizer, e.g., GAN). + work_dir (str, optional): The working directory to save checkpoints + and logs. Defaults to None. + logger (:obj:`logging.Logger`): Logger used during training. + Defaults to None. (The default value is just for backward + compatibility) + meta (dict | None): A dict records some import information such as + environment info and seed, which will be logged in logger hook. + Defaults to None. + max_epochs (int, optional): Total training epochs. + max_iters (int, optional): Total training iterations. + """ + + def __init__(self, + model, + batch_processor=None, + optimizer=None, + work_dir=None, + logger=None, + meta=None, + max_iters=None, + max_epochs=None): + if batch_processor is not None: + if not callable(batch_processor): + raise TypeError('batch_processor must be callable, ' + f'but got {type(batch_processor)}') + warnings.warn('batch_processor is deprecated, please implement ' + 'train_step() and val_step() in the model instead.') + # raise an error is `batch_processor` is not None and + # `model.train_step()` exists. + if is_module_wrapper(model): + _model = model.module + else: + _model = model + if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'): + raise RuntimeError( + 'batch_processor and model.train_step()/model.val_step() ' + 'cannot be both available.') + else: + assert hasattr(model, 'train_step') + + # check the type of `optimizer` + if isinstance(optimizer, dict): + for name, optim in optimizer.items(): + if not isinstance(optim, Optimizer): + raise TypeError( + f'optimizer must be a dict of torch.optim.Optimizers, ' + f'but optimizer["{name}"] is a {type(optim)}') + elif not isinstance(optimizer, Optimizer) and optimizer is not None: + raise TypeError( + f'optimizer must be a torch.optim.Optimizer object ' + f'or dict or None, but got {type(optimizer)}') + + # check the type of `logger` + if not isinstance(logger, logging.Logger): + raise TypeError(f'logger must be a logging.Logger object, ' + f'but got {type(logger)}') + + # check the type of `meta` + if meta is not None and not isinstance(meta, dict): + raise TypeError( + f'meta must be a dict or None, but got {type(meta)}') + + self.model = model + self.batch_processor = batch_processor + self.optimizer = optimizer + self.logger = logger + self.meta = meta + # create work_dir + if mmcv.is_str(work_dir): + self.work_dir = osp.abspath(work_dir) + mmcv.mkdir_or_exist(self.work_dir) + elif work_dir is None: + self.work_dir = None + else: + raise TypeError('"work_dir" must be a str or None') + + # get model name from the model class + if hasattr(self.model, 'module'): + self._model_name = self.model.module.__class__.__name__ + else: + self._model_name = self.model.__class__.__name__ + + self._rank, self._world_size = get_dist_info() + self.timestamp = get_time_str() + self.mode = None + self._hooks = [] + self._epoch = 0 + self._iter = 0 + self._inner_iter = 0 + + if max_epochs is not None and max_iters is not None: + raise ValueError( + 'Only one of `max_epochs` or `max_iters` can be set.') + + self._max_epochs = max_epochs + self._max_iters = max_iters + # TODO: Redesign LogBuffer, it is not flexible and elegant enough + self.log_buffer = LogBuffer() + + @property + def model_name(self): + """str: Name of the model, usually the module class name.""" + return self._model_name + + @property + def rank(self): + """int: Rank of current process. (distributed training)""" + return self._rank + + @property + def world_size(self): + """int: Number of processes participating in the job. + (distributed training)""" + return self._world_size + + @property + def hooks(self): + """list[:obj:`Hook`]: A list of registered hooks.""" + return self._hooks + + @property + def epoch(self): + """int: Current epoch.""" + return self._epoch + + @property + def iter(self): + """int: Current iteration.""" + return self._iter + + @property + def inner_iter(self): + """int: Iteration in an epoch.""" + return self._inner_iter + + @property + def max_epochs(self): + """int: Maximum training epochs.""" + return self._max_epochs + + @property + def max_iters(self): + """int: Maximum training iterations.""" + return self._max_iters + + @abstractmethod + def train(self): + pass + + @abstractmethod + def val(self): + pass + + @abstractmethod + def run(self, data_loaders, workflow, **kwargs): + pass + + @abstractmethod + def save_checkpoint(self, + out_dir, + filename_tmpl, + save_optimizer=True, + meta=None, + create_symlink=True): + pass + + def current_lr(self): + """Get current learning rates. + + Returns: + list[float] | dict[str, list[float]]: Current learning rates of all + param groups. If the runner has a dict of optimizers, this + method will return a dict. + """ + if isinstance(self.optimizer, torch.optim.Optimizer): + lr = [group['lr'] for group in self.optimizer.param_groups] + elif isinstance(self.optimizer, dict): + lr = dict() + for name, optim in self.optimizer.items(): + lr[name] = [group['lr'] for group in optim.param_groups] + else: + raise RuntimeError( + 'lr is not applicable because optimizer does not exist.') + return lr + + def current_momentum(self): + """Get current momentums. + + Returns: + list[float] | dict[str, list[float]]: Current momentums of all + param groups. If the runner has a dict of optimizers, this + method will return a dict. + """ + + def _get_momentum(optimizer): + momentums = [] + for group in optimizer.param_groups: + if 'momentum' in group.keys(): + momentums.append(group['momentum']) + elif 'betas' in group.keys(): + momentums.append(group['betas'][0]) + else: + momentums.append(0) + return momentums + + if self.optimizer is None: + raise RuntimeError( + 'momentum is not applicable because optimizer does not exist.') + elif isinstance(self.optimizer, torch.optim.Optimizer): + momentums = _get_momentum(self.optimizer) + elif isinstance(self.optimizer, dict): + momentums = dict() + for name, optim in self.optimizer.items(): + momentums[name] = _get_momentum(optim) + return momentums + + def register_hook(self, hook, priority='NORMAL'): + """Register a hook into the hook list. + + The hook will be inserted into a priority queue, with the specified + priority (See :class:`Priority` for details of priorities). + For hooks with the same priority, they will be triggered in the same + order as they are registered. + + Args: + hook (:obj:`Hook`): The hook to be registered. + priority (int or str or :obj:`Priority`): Hook priority. + Lower value means higher priority. + """ + assert isinstance(hook, Hook) + if hasattr(hook, 'priority'): + raise ValueError('"priority" is a reserved attribute for hooks') + priority = get_priority(priority) + hook.priority = priority + # insert the hook to a sorted list + inserted = False + for i in range(len(self._hooks) - 1, -1, -1): + if priority >= self._hooks[i].priority: + self._hooks.insert(i + 1, hook) + inserted = True + break + if not inserted: + self._hooks.insert(0, hook) + + def register_hook_from_cfg(self, hook_cfg): + """Register a hook from its cfg. + + Args: + hook_cfg (dict): Hook config. It should have at least keys 'type' + and 'priority' indicating its type and priority. + + Notes: + The specific hook class to register should not use 'type' and + 'priority' arguments during initialization. + """ + hook_cfg = hook_cfg.copy() + priority = hook_cfg.pop('priority', 'NORMAL') + hook = mmcv.build_from_cfg(hook_cfg, HOOKS) + self.register_hook(hook, priority=priority) + + def call_hook(self, fn_name): + """Call all hooks. + + Args: + fn_name (str): The function name in each hook to be called, such as + "before_train_epoch". + """ + for hook in self._hooks: + getattr(hook, fn_name)(self) + + def get_hook_info(self): + # Get hooks info in each stage + stage_hook_map = {stage: [] for stage in Hook.stages} + for hook in self.hooks: + try: + priority = Priority(hook.priority).name + except ValueError: + priority = hook.priority + classname = hook.__class__.__name__ + hook_info = f'({priority:<12}) {classname:<35}' + for trigger_stage in hook.get_triggered_stages(): + stage_hook_map[trigger_stage].append(hook_info) + + stage_hook_infos = [] + for stage in Hook.stages: + hook_infos = stage_hook_map[stage] + if len(hook_infos) > 0: + info = f'{stage}:\n' + info += '\n'.join(hook_infos) + info += '\n -------------------- ' + stage_hook_infos.append(info) + return '\n'.join(stage_hook_infos) + + def load_checkpoint(self, + filename, + map_location='cpu', + strict=False, + revise_keys=[(r'^module.', '')]): + return load_checkpoint( + self.model, + filename, + map_location, + strict, + self.logger, + revise_keys=revise_keys) + + def resume(self, + checkpoint, + resume_optimizer=True, + map_location='default'): + if map_location == 'default': + if torch.cuda.is_available(): + device_id = torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + checkpoint = self.load_checkpoint(checkpoint) + else: + checkpoint = self.load_checkpoint( + checkpoint, map_location=map_location) + + self._epoch = checkpoint['meta']['epoch'] + self._iter = checkpoint['meta']['iter'] + if self.meta is None: + self.meta = {} + self.meta.setdefault('hook_msgs', {}) + # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages + self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {})) + + # Re-calculate the number of iterations when resuming + # models with different number of GPUs + if 'config' in checkpoint['meta']: + config = mmcv.Config.fromstring( + checkpoint['meta']['config'], file_format='.py') + previous_gpu_ids = config.get('gpu_ids', None) + if previous_gpu_ids and len(previous_gpu_ids) > 0 and len( + previous_gpu_ids) != self.world_size: + self._iter = int(self._iter * len(previous_gpu_ids) / + self.world_size) + self.logger.info('the iteration number is changed due to ' + 'change of GPU number') + + # resume meta information meta + self.meta = checkpoint['meta'] + + if 'optimizer' in checkpoint and resume_optimizer: + if isinstance(self.optimizer, Optimizer): + self.optimizer.load_state_dict(checkpoint['optimizer']) + elif isinstance(self.optimizer, dict): + for k in self.optimizer.keys(): + self.optimizer[k].load_state_dict( + checkpoint['optimizer'][k]) + else: + raise TypeError( + 'Optimizer should be dict or torch.optim.Optimizer ' + f'but got {type(self.optimizer)}') + + self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter) + + def register_lr_hook(self, lr_config): + if lr_config is None: + return + elif isinstance(lr_config, dict): + assert 'policy' in lr_config + policy_type = lr_config.pop('policy') + # If the type of policy is all in lower case, e.g., 'cyclic', + # then its first letter will be capitalized, e.g., to be 'Cyclic'. + # This is for the convenient usage of Lr updater. + # Since this is not applicable for ` + # CosineAnnealingLrUpdater`, + # the string will not be changed if it contains capital letters. + if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'LrUpdaterHook' + lr_config['type'] = hook_type + hook = mmcv.build_from_cfg(lr_config, HOOKS) + else: + hook = lr_config + self.register_hook(hook, priority='VERY_HIGH') + + def register_momentum_hook(self, momentum_config): + if momentum_config is None: + return + if isinstance(momentum_config, dict): + assert 'policy' in momentum_config + policy_type = momentum_config.pop('policy') + # If the type of policy is all in lower case, e.g., 'cyclic', + # then its first letter will be capitalized, e.g., to be 'Cyclic'. + # This is for the convenient usage of momentum updater. + # Since this is not applicable for + # `CosineAnnealingMomentumUpdater`, + # the string will not be changed if it contains capital letters. + if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'MomentumUpdaterHook' + momentum_config['type'] = hook_type + hook = mmcv.build_from_cfg(momentum_config, HOOKS) + else: + hook = momentum_config + self.register_hook(hook, priority='HIGH') + + def register_optimizer_hook(self, optimizer_config): + if optimizer_config is None: + return + if isinstance(optimizer_config, dict): + optimizer_config.setdefault('type', 'OptimizerHook') + hook = mmcv.build_from_cfg(optimizer_config, HOOKS) + else: + hook = optimizer_config + self.register_hook(hook, priority='ABOVE_NORMAL') + + def register_checkpoint_hook(self, checkpoint_config): + if checkpoint_config is None: + return + if isinstance(checkpoint_config, dict): + checkpoint_config.setdefault('type', 'CheckpointHook') + hook = mmcv.build_from_cfg(checkpoint_config, HOOKS) + else: + hook = checkpoint_config + self.register_hook(hook, priority='NORMAL') + + def register_logger_hooks(self, log_config): + if log_config is None: + return + log_interval = log_config['interval'] + for info in log_config['hooks']: + logger_hook = mmcv.build_from_cfg( + info, HOOKS, default_args=dict(interval=log_interval)) + self.register_hook(logger_hook, priority='VERY_LOW') + + def register_timer_hook(self, timer_config): + if timer_config is None: + return + if isinstance(timer_config, dict): + timer_config_ = copy.deepcopy(timer_config) + hook = mmcv.build_from_cfg(timer_config_, HOOKS) + else: + hook = timer_config + self.register_hook(hook, priority='LOW') + + def register_custom_hooks(self, custom_config): + if custom_config is None: + return + + if not isinstance(custom_config, list): + custom_config = [custom_config] + + for item in custom_config: + if isinstance(item, dict): + self.register_hook_from_cfg(item) + else: + self.register_hook(item, priority='NORMAL') + + def register_profiler_hook(self, profiler_config): + if profiler_config is None: + return + if isinstance(profiler_config, dict): + profiler_config.setdefault('type', 'ProfilerHook') + hook = mmcv.build_from_cfg(profiler_config, HOOKS) + else: + hook = profiler_config + self.register_hook(hook) + + def register_training_hooks(self, + lr_config, + optimizer_config=None, + checkpoint_config=None, + log_config=None, + momentum_config=None, + timer_config=dict(type='IterTimerHook'), + custom_hooks_config=None): + """Register default and custom hooks for training. + + Default and custom hooks include: + + +----------------------+-------------------------+ + | Hooks | Priority | + +======================+=========================+ + | LrUpdaterHook | VERY_HIGH (10) | + +----------------------+-------------------------+ + | MomentumUpdaterHook | HIGH (30) | + +----------------------+-------------------------+ + | OptimizerStepperHook | ABOVE_NORMAL (40) | + +----------------------+-------------------------+ + | CheckpointSaverHook | NORMAL (50) | + +----------------------+-------------------------+ + | IterTimerHook | LOW (70) | + +----------------------+-------------------------+ + | LoggerHook(s) | VERY_LOW (90) | + +----------------------+-------------------------+ + | CustomHook(s) | defaults to NORMAL (50) | + +----------------------+-------------------------+ + + If custom hooks have same priority with default hooks, custom hooks + will be triggered after default hooks. + """ + self.register_lr_hook(lr_config) + self.register_momentum_hook(momentum_config) + self.register_optimizer_hook(optimizer_config) + self.register_checkpoint_hook(checkpoint_config) + self.register_timer_hook(timer_config) + self.register_logger_hooks(log_config) + self.register_custom_hooks(custom_hooks_config) diff --git a/annotator/uniformer/mmcv/runner/builder.py b/annotator/uniformer/mmcv/runner/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..77c96ba0b2f30ead9da23f293c5dc84dd3e4a74f --- /dev/null +++ b/annotator/uniformer/mmcv/runner/builder.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +from ..utils import Registry + +RUNNERS = Registry('runner') +RUNNER_BUILDERS = Registry('runner builder') + + +def build_runner_constructor(cfg): + return RUNNER_BUILDERS.build(cfg) + + +def build_runner(cfg, default_args=None): + runner_cfg = copy.deepcopy(cfg) + constructor_type = runner_cfg.pop('constructor', + 'DefaultRunnerConstructor') + runner_constructor = build_runner_constructor( + dict( + type=constructor_type, + runner_cfg=runner_cfg, + default_args=default_args)) + runner = runner_constructor() + return runner diff --git a/annotator/uniformer/mmcv/runner/checkpoint.py b/annotator/uniformer/mmcv/runner/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..b29ca320679164432f446adad893e33fb2b4b29e --- /dev/null +++ b/annotator/uniformer/mmcv/runner/checkpoint.py @@ -0,0 +1,707 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import re +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import torch +import torchvision +from torch.optim import Optimizer +from torch.utils import model_zoo + +import annotator.uniformer.mmcv as mmcv +from ..fileio import FileClient +from ..fileio import load as load_file +from ..parallel import is_module_wrapper +from ..utils import mkdir_or_exist +from .dist_utils import get_dist_info + +ENV_MMCV_HOME = 'MMCV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +class CheckpointLoader: + """A general checkpoint loader to manage all schemes.""" + + _schemes = {} + + @classmethod + def _register_scheme(cls, prefixes, loader, force=False): + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if (prefix not in cls._schemes) or force: + cls._schemes[prefix] = loader + else: + raise KeyError( + f'{prefix} is already registered as a loader backend, ' + 'add "force=True" if you want to override it') + # sort, longer prefixes take priority + cls._schemes = OrderedDict( + sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) + + @classmethod + def register_scheme(cls, prefixes, loader=None, force=False): + """Register a loader to CheckpointLoader. + + This method can be used as a normal class method or a decorator. + + Args: + prefixes (str or list[str] or tuple[str]): + The prefix of the registered loader. + loader (function, optional): The loader function to be registered. + When this method is used as a decorator, loader is None. + Defaults to None. + force (bool, optional): Whether to override the loader + if the prefix has already been registered. Defaults to False. + """ + + if loader is not None: + cls._register_scheme(prefixes, loader, force=force) + return + + def _register(loader_cls): + cls._register_scheme(prefixes, loader_cls, force=force) + return loader_cls + + return _register + + @classmethod + def _get_checkpoint_loader(cls, path): + """Finds a loader that supports the given path. Falls back to the local + loader if no other loader is found. + + Args: + path (str): checkpoint path + + Returns: + loader (function): checkpoint loader + """ + + for p in cls._schemes: + if path.startswith(p): + return cls._schemes[p] + + @classmethod + def load_checkpoint(cls, filename, map_location=None, logger=None): + """load checkpoint through URL scheme path. + + Args: + filename (str): checkpoint file name with given prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + logger (:mod:`logging.Logger`, optional): The logger for message. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint_loader = cls._get_checkpoint_loader(filename) + class_name = checkpoint_loader.__name__ + mmcv.print_log( + f'load checkpoint from {class_name[10:]} path: {filename}', logger) + return checkpoint_loader(filename, map_location) + + +@CheckpointLoader.register_scheme(prefixes='') +def load_from_local(filename, map_location): + """load checkpoint by local file path. + + Args: + filename (str): local checkpoint file path + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) +def load_from_http(filename, map_location=None, model_dir=None): + """load checkpoint through HTTP or HTTPS scheme path. In distributed + setting, this function only download checkpoint at local rank 0. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + model_dir (string, optional): directory in which to save the object, + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url( + filename, model_dir=model_dir, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url( + filename, model_dir=model_dir, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='pavi://') +def load_from_pavi(filename, map_location=None): + """load checkpoint through the file path prefixed with pavi. In distributed + setting, this function download ckpt at all ranks to different temporary + directories. + + Args: + filename (str): checkpoint file path with pavi prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + assert filename.startswith('pavi://'), \ + f'Expected filename startswith `pavi://`, but get {filename}' + model_path = filename[7:] + + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='s3://') +def load_from_ceph(filename, map_location=None, backend='petrel'): + """load checkpoint through the file path prefixed with s3. In distributed + setting, this function download ckpt at all ranks to different temporary + directories. + + Args: + filename (str): checkpoint file path with s3 prefix + map_location (str, optional): Same as :func:`torch.load`. + backend (str, optional): The storage backend type. Options are 'ceph', + 'petrel'. Default: 'petrel'. + + .. warning:: + :class:`mmcv.fileio.file_client.CephBackend` will be deprecated, + please use :class:`mmcv.fileio.file_client.PetrelBackend` instead. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + allowed_backends = ['ceph', 'petrel'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + + if backend == 'ceph': + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead') + + # CephClient and PetrelBackend have the same prefix 's3://' and the latter + # will be chosen as default. If PetrelBackend can not be instantiated + # successfully, the CephClient will be chosen. + try: + file_client = FileClient(backend=backend) + except ImportError: + allowed_backends.remove(backend) + file_client = FileClient(backend=allowed_backends[0]) + + with io.BytesIO(file_client.get(filename)) as buffer: + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) +def load_from_torchvision(filename, map_location=None): + """load checkpoint through the file path prefixed with modelzoo or + torchvision. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + model_urls = get_torchvision_models() + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_name = filename[11:] + else: + model_name = filename[14:] + return load_from_http(model_urls[model_name], map_location=map_location) + + +@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) +def load_from_openmmlab(filename, map_location=None): + """load checkpoint through the file path prefixed with open-mmlab or + openmmlab. + + Args: + filename (str): checkpoint file path with open-mmlab or + openmmlab prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + model_urls = get_external_models() + prefix_str = 'open-mmlab://' + if filename.startswith(prefix_str): + model_name = filename[13:] + else: + model_name = filename[12:] + prefix_str = 'openmmlab://' + + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'{prefix_str}{model_name} is deprecated in favor ' + f'of {prefix_str}{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_from_http(model_url, map_location=map_location) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='mmcls://') +def load_from_mmcls(filename, map_location=None): + """load checkpoint through the file path prefixed with mmcls. + + Args: + filename (str): checkpoint file path with mmcls prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_from_http( + model_urls[model_name], map_location=map_location) + checkpoint = _process_mmcls_checkpoint(checkpoint) + return checkpoint + + +def _load_checkpoint(filename, map_location=None, logger=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str, optional): Same as :func:`torch.load`. + Default: None. + logger (:mod:`logging.Logger`, optional): The logger for error message. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + return CheckpointLoader.load_checkpoint(filename, map_location, logger) + + +def _load_checkpoint_with_prefix(prefix, filename, map_location=None): + """Load partial pretrained model with specific prefix. + + Args: + prefix (str): The prefix of sub-module. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint = _load_checkpoint(filename, map_location=map_location) + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if not prefix.endswith('.'): + prefix += '.' + prefix_len = len(prefix) + + state_dict = { + k[prefix_len:]: v + for k, v in state_dict.items() if k.startswith(prefix) + } + + assert state_dict, f'{prefix} is not in the pretrained model' + return state_dict + + +def load_checkpoint(model, + filename, + map_location=None, + strict=False, + logger=None, + revise_keys=[(r'^module\.', '')]): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + revise_keys (list): A list of customized keywords to modify the + state_dict in checkpoint. Each item is a (pattern, replacement) + pair of the regular expression operations. Default: strip + the prefix 'module.' by [(r'^module\\.', '')]. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location, logger) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + # strip prefix of state_dict + metadata = getattr(state_dict, '_metadata', OrderedDict()) + for p, r in revise_keys: + state_dict = OrderedDict( + {re.sub(p, r, k): v + for k, v in state_dict.items()}) + # Keep metadata in state_dict + state_dict._metadata = metadata + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + # Keep metadata in state_dict + state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict()) + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, + filename, + optimizer=None, + meta=None, + file_client_args=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + if file_client_args is not None: + raise ValueError( + 'file_client_args should be "None" if filename starts with' + f'"pavi://", but got {file_client_args}') + try: + from pavi import modelcloud + from pavi import exception + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except exception.NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + file_client = FileClient.infer_client(file_client_args, filename) + with io.BytesIO() as f: + torch.save(checkpoint, f) + file_client.put(f.getvalue(), filename) diff --git a/annotator/uniformer/mmcv/runner/default_constructor.py b/annotator/uniformer/mmcv/runner/default_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..3f1f5b44168768dfda3947393a63a6cf9cf50b41 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/default_constructor.py @@ -0,0 +1,44 @@ +from .builder import RUNNER_BUILDERS, RUNNERS + + +@RUNNER_BUILDERS.register_module() +class DefaultRunnerConstructor: + """Default constructor for runners. + + Custom existing `Runner` like `EpocBasedRunner` though `RunnerConstructor`. + For example, We can inject some new properties and functions for `Runner`. + + Example: + >>> from annotator.uniformer.mmcv.runner import RUNNER_BUILDERS, build_runner + >>> # Define a new RunnerReconstructor + >>> @RUNNER_BUILDERS.register_module() + >>> class MyRunnerConstructor: + ... def __init__(self, runner_cfg, default_args=None): + ... if not isinstance(runner_cfg, dict): + ... raise TypeError('runner_cfg should be a dict', + ... f'but got {type(runner_cfg)}') + ... self.runner_cfg = runner_cfg + ... self.default_args = default_args + ... + ... def __call__(self): + ... runner = RUNNERS.build(self.runner_cfg, + ... default_args=self.default_args) + ... # Add new properties for existing runner + ... runner.my_name = 'my_runner' + ... runner.my_function = lambda self: print(self.my_name) + ... ... + >>> # build your runner + >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40, + ... constructor='MyRunnerConstructor') + >>> runner = build_runner(runner_cfg) + """ + + def __init__(self, runner_cfg, default_args=None): + if not isinstance(runner_cfg, dict): + raise TypeError('runner_cfg should be a dict', + f'but got {type(runner_cfg)}') + self.runner_cfg = runner_cfg + self.default_args = default_args + + def __call__(self): + return RUNNERS.build(self.runner_cfg, default_args=self.default_args) diff --git a/annotator/uniformer/mmcv/runner/dist_utils.py b/annotator/uniformer/mmcv/runner/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a1ef3fda5ceeb31bf15a73779da1b1903ab0fe --- /dev/null +++ b/annotator/uniformer/mmcv/runner/dist_utils.py @@ -0,0 +1,164 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import os +import subprocess +from collections import OrderedDict + +import torch +import torch.multiprocessing as mp +from torch import distributed as dist +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_mpi(backend, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + # use MASTER_ADDR in the environment variable if it already exists + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def allreduce_params(params, coalesce=True, bucket_size_mb=-1): + """Allreduce parameters. + + Args: + params (list[torch.Parameters]): List of parameters or buffers of a + model. + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + """ + _, world_size = get_dist_info() + if world_size == 1: + return + params = [param.data for param in params] + if coalesce: + _allreduce_coalesced(params, world_size, bucket_size_mb) + else: + for tensor in params: + dist.all_reduce(tensor.div_(world_size)) + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + """Allreduce gradients. + + Args: + params (list[torch.Parameters]): List of parameters of a model + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + """ + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + _, world_size = get_dist_info() + if world_size == 1: + return + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) diff --git a/annotator/uniformer/mmcv/runner/epoch_based_runner.py b/annotator/uniformer/mmcv/runner/epoch_based_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..766a9ce6afdf09cd11b1b15005f5132583011348 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/epoch_based_runner.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import platform +import shutil +import time +import warnings + +import torch + +import annotator.uniformer.mmcv as mmcv +from .base_runner import BaseRunner +from .builder import RUNNERS +from .checkpoint import save_checkpoint +from .utils import get_host_info + + +@RUNNERS.register_module() +class EpochBasedRunner(BaseRunner): + """Epoch-based Runner. + + This runner train models epoch by epoch. + """ + + def run_iter(self, data_batch, train_mode, **kwargs): + if self.batch_processor is not None: + outputs = self.batch_processor( + self.model, data_batch, train_mode=train_mode, **kwargs) + elif train_mode: + outputs = self.model.train_step(data_batch, self.optimizer, + **kwargs) + else: + outputs = self.model.val_step(data_batch, self.optimizer, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('"batch_processor()" or "model.train_step()"' + 'and "model.val_step()" must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + + def train(self, data_loader, **kwargs): + self.model.train() + self.mode = 'train' + self.data_loader = data_loader + self._max_iters = self._max_epochs * len(self.data_loader) + self.call_hook('before_train_epoch') + time.sleep(2) # Prevent possible deadlock during epoch transition + for i, data_batch in enumerate(self.data_loader): + self._inner_iter = i + self.call_hook('before_train_iter') + self.run_iter(data_batch, train_mode=True, **kwargs) + self.call_hook('after_train_iter') + self._iter += 1 + + self.call_hook('after_train_epoch') + self._epoch += 1 + + @torch.no_grad() + def val(self, data_loader, **kwargs): + self.model.eval() + self.mode = 'val' + self.data_loader = data_loader + self.call_hook('before_val_epoch') + time.sleep(2) # Prevent possible deadlock during epoch transition + for i, data_batch in enumerate(self.data_loader): + self._inner_iter = i + self.call_hook('before_val_iter') + self.run_iter(data_batch, train_mode=False) + self.call_hook('after_val_iter') + + self.call_hook('after_val_epoch') + + def run(self, data_loaders, workflow, max_epochs=None, **kwargs): + """Start running. + + Args: + data_loaders (list[:obj:`DataLoader`]): Dataloaders for training + and validation. + workflow (list[tuple]): A list of (phase, epochs) to specify the + running order and epochs. E.g, [('train', 2), ('val', 1)] means + running 2 epochs for training and 1 epoch for validation, + iteratively. + """ + assert isinstance(data_loaders, list) + assert mmcv.is_list_of(workflow, tuple) + assert len(data_loaders) == len(workflow) + if max_epochs is not None: + warnings.warn( + 'setting max_epochs in run is deprecated, ' + 'please set max_epochs in runner_config', DeprecationWarning) + self._max_epochs = max_epochs + + assert self._max_epochs is not None, ( + 'max_epochs must be specified during instantiation') + + for i, flow in enumerate(workflow): + mode, epochs = flow + if mode == 'train': + self._max_iters = self._max_epochs * len(data_loaders[i]) + break + + work_dir = self.work_dir if self.work_dir is not None else 'NONE' + self.logger.info('Start running, host: %s, work_dir: %s', + get_host_info(), work_dir) + self.logger.info('Hooks will be executed in the following order:\n%s', + self.get_hook_info()) + self.logger.info('workflow: %s, max: %d epochs', workflow, + self._max_epochs) + self.call_hook('before_run') + + while self.epoch < self._max_epochs: + for i, flow in enumerate(workflow): + mode, epochs = flow + if isinstance(mode, str): # self.train() + if not hasattr(self, mode): + raise ValueError( + f'runner has no method named "{mode}" to run an ' + 'epoch') + epoch_runner = getattr(self, mode) + else: + raise TypeError( + 'mode in workflow must be a str, but got {}'.format( + type(mode))) + + for _ in range(epochs): + if mode == 'train' and self.epoch >= self._max_epochs: + break + epoch_runner(data_loaders[i], **kwargs) + + time.sleep(1) # wait for some hooks like loggers to finish + self.call_hook('after_run') + + def save_checkpoint(self, + out_dir, + filename_tmpl='epoch_{}.pth', + save_optimizer=True, + meta=None, + create_symlink=True): + """Save the checkpoint. + + Args: + out_dir (str): The directory that checkpoints are saved. + filename_tmpl (str, optional): The checkpoint filename template, + which contains a placeholder for the epoch number. + Defaults to 'epoch_{}.pth'. + save_optimizer (bool, optional): Whether to save the optimizer to + the checkpoint. Defaults to True. + meta (dict, optional): The meta information to be saved in the + checkpoint. Defaults to None. + create_symlink (bool, optional): Whether to create a symlink + "latest.pth" to point to the latest checkpoint. + Defaults to True. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError( + f'meta should be a dict or None, but got {type(meta)}') + if self.meta is not None: + meta.update(self.meta) + # Note: meta.update(self.meta) should be done before + # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise + # there will be problems with resumed checkpoints. + # More details in https://github.com/open-mmlab/mmcv/pull/1108 + meta.update(epoch=self.epoch + 1, iter=self.iter) + + filename = filename_tmpl.format(self.epoch + 1) + filepath = osp.join(out_dir, filename) + optimizer = self.optimizer if save_optimizer else None + save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) + # in some environments, `os.symlink` is not supported, you may need to + # set `create_symlink` to False + if create_symlink: + dst_file = osp.join(out_dir, 'latest.pth') + if platform.system() != 'Windows': + mmcv.symlink(filename, dst_file) + else: + shutil.copy(filepath, dst_file) + + +@RUNNERS.register_module() +class Runner(EpochBasedRunner): + """Deprecated name of EpochBasedRunner.""" + + def __init__(self, *args, **kwargs): + warnings.warn( + 'Runner was deprecated, please use EpochBasedRunner instead') + super().__init__(*args, **kwargs) diff --git a/annotator/uniformer/mmcv/runner/fp16_utils.py b/annotator/uniformer/mmcv/runner/fp16_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1981011d6859192e3e663e29d13500d56ba47f6c --- /dev/null +++ b/annotator/uniformer/mmcv/runner/fp16_utils.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import warnings +from collections import abc +from inspect import getfullargspec + +import numpy as np +import torch +import torch.nn as nn + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version +from .dist_utils import allreduce_grads as _allreduce_grads + +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported + # and used; otherwise, auto fp16 will adopt mmcv's implementation. + # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16 + # manually, so the behavior may not be consistent with real amp. + from torch.cuda.amp import autocast +except ImportError: + pass + + +def cast_tensor_type(inputs, src_type, dst_type): + """Recursively convert Tensor in inputs from src_type to dst_type. + + Args: + inputs: Inputs that to be casted. + src_type (torch.dtype): Source type.. + dst_type (torch.dtype): Destination type. + + Returns: + The same type with inputs, but all contained Tensors have been cast. + """ + if isinstance(inputs, nn.Module): + return inputs + elif isinstance(inputs, torch.Tensor): + return inputs.to(dst_type) + elif isinstance(inputs, str): + return inputs + elif isinstance(inputs, np.ndarray): + return inputs + elif isinstance(inputs, abc.Mapping): + return type(inputs)({ + k: cast_tensor_type(v, src_type, dst_type) + for k, v in inputs.items() + }) + elif isinstance(inputs, abc.Iterable): + return type(inputs)( + cast_tensor_type(item, src_type, dst_type) for item in inputs) + else: + return inputs + + +def auto_fp16(apply_to=None, out_fp32=False): + """Decorator to enable fp16 training automatically. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If inputs arguments are fp32 tensors, they will + be converted to fp16 automatically. Arguments other than fp32 tensors are + ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original mmcv implementation will be adopted. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp32 (bool): Whether to convert the output back to fp32. + + Example: + + >>> import torch.nn as nn + >>> class MyModule1(nn.Module): + >>> + >>> # Convert x and y to fp16 + >>> @auto_fp16() + >>> def forward(self, x, y): + >>> pass + + >>> import torch.nn as nn + >>> class MyModule2(nn.Module): + >>> + >>> # convert pred to fp16 + >>> @auto_fp16(apply_to=('pred', )) + >>> def do_something(self, pred, others): + >>> pass + """ + + def auto_fp16_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # check if the module has set the attribute `fp16_enabled`, if not, + # just fallback to the original method. + if not isinstance(args[0], torch.nn.Module): + raise TypeError('@auto_fp16 can only be used to decorate the ' + 'method of nn.Module') + if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): + return old_func(*args, **kwargs) + + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + # NOTE: default args are not taken into consideration + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.float, torch.half)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = {} + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.float, torch.half) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + if (TORCH_VERSION != 'parrots' and + digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + with autocast(enabled=True): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp32: + output = cast_tensor_type(output, torch.half, torch.float) + return output + + return new_func + + return auto_fp16_wrapper + + +def force_fp32(apply_to=None, out_fp16=False): + """Decorator to convert input arguments to fp32 in force. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If there are some inputs that must be processed + in fp32 mode, then this decorator can handle it. If inputs arguments are + fp16 tensors, they will be converted to fp32 automatically. Arguments other + than fp16 tensors are ignored. If you are using PyTorch >= 1.6, + torch.cuda.amp is used as the backend, otherwise, original mmcv + implementation will be adopted. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp16 (bool): Whether to convert the output back to fp16. + + Example: + + >>> import torch.nn as nn + >>> class MyModule1(nn.Module): + >>> + >>> # Convert x and y to fp32 + >>> @force_fp32() + >>> def loss(self, x, y): + >>> pass + + >>> import torch.nn as nn + >>> class MyModule2(nn.Module): + >>> + >>> # convert pred to fp32 + >>> @force_fp32(apply_to=('pred', )) + >>> def post_process(self, pred, others): + >>> pass + """ + + def force_fp32_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # check if the module has set the attribute `fp16_enabled`, if not, + # just fallback to the original method. + if not isinstance(args[0], torch.nn.Module): + raise TypeError('@force_fp32 can only be used to decorate the ' + 'method of nn.Module') + if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): + return old_func(*args, **kwargs) + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.half, torch.float)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = dict() + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.half, torch.float) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + if (TORCH_VERSION != 'parrots' and + digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + with autocast(enabled=False): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp16: + output = cast_tensor_type(output, torch.float, torch.half) + return output + + return new_func + + return force_fp32_wrapper + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + warnings.warning( + '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be ' + 'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads') + _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb) + + +def wrap_fp16_model(model): + """Wrap the FP32 model to FP16. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original mmcv implementation will be adopted. + + For PyTorch >= 1.6, this function will + 1. Set fp16 flag inside the model to True. + + Otherwise: + 1. Convert FP32 model to FP16. + 2. Remain some necessary layers to be FP32, e.g., normalization layers. + 3. Set `fp16_enabled` flag inside the model to True. + + Args: + model (nn.Module): Model in FP32. + """ + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.6.0')): + # convert model to fp16 + model.half() + # patch the normalization layers to make it work in fp32 mode + patch_norm_fp32(model) + # set `fp16_enabled` flag + for m in model.modules(): + if hasattr(m, 'fp16_enabled'): + m.fp16_enabled = True + + +def patch_norm_fp32(module): + """Recursively convert normalization layers from FP16 to FP32. + + Args: + module (nn.Module): The modules to be converted in FP16. + + Returns: + nn.Module: The converted module, the normalization layers have been + converted to FP32. + """ + if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)): + module.float() + if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3': + module.forward = patch_forward_method(module.forward, torch.half, + torch.float) + for child in module.children(): + patch_norm_fp32(child) + return module + + +def patch_forward_method(func, src_type, dst_type, convert_output=True): + """Patch the forward method of a module. + + Args: + func (callable): The original forward method. + src_type (torch.dtype): Type of input arguments to be converted from. + dst_type (torch.dtype): Type of input arguments to be converted to. + convert_output (bool): Whether to convert the output back to src_type. + + Returns: + callable: The patched forward method. + """ + + def new_forward(*args, **kwargs): + output = func(*cast_tensor_type(args, src_type, dst_type), + **cast_tensor_type(kwargs, src_type, dst_type)) + if convert_output: + output = cast_tensor_type(output, dst_type, src_type) + return output + + return new_forward + + +class LossScaler: + """Class that manages loss scaling in mixed precision training which + supports both dynamic or static mode. + + The implementation refers to + https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py. + Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling. + It's important to understand how :class:`LossScaler` operates. + Loss scaling is designed to combat the problem of underflowing + gradients encountered at long times when training fp16 networks. + Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. + If overflowing gradients are encountered, :class:`FP16_Optimizer` then + skips the update step for this particular iteration/minibatch, + and :class:`LossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients + detected,:class:`LossScaler` increases the loss scale once more. + In this way :class:`LossScaler` attempts to "ride the edge" of always + using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float): Initial loss scale value, default: 2**32. + scale_factor (float): Factor used when adjusting the loss scale. + Default: 2. + mode (str): Loss scaling mode. 'dynamic' or 'static' + scale_window (int): Number of consecutive iterations without an + overflow to wait before increasing the loss scale. Default: 1000. + """ + + def __init__(self, + init_scale=2**32, + mode='dynamic', + scale_factor=2., + scale_window=1000): + self.cur_scale = init_scale + self.cur_iter = 0 + assert mode in ('dynamic', + 'static'), 'mode can only be dynamic or static' + self.mode = mode + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + + def has_overflow(self, params): + """Check if params contain overflow.""" + if self.mode != 'dynamic': + return False + for p in params: + if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data): + return True + return False + + def _has_inf_or_nan(x): + """Check if params contain NaN.""" + try: + cpu_sum = float(x.float().sum()) + except RuntimeError as instance: + if 'value cannot be converted' not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') \ + or cpu_sum != cpu_sum: + return True + return False + + def update_scale(self, overflow): + """update the current loss scale value when overflow happens.""" + if self.mode != 'dynamic': + return + if overflow: + self.cur_scale = max(self.cur_scale / self.scale_factor, 1) + self.last_overflow_iter = self.cur_iter + else: + if (self.cur_iter - self.last_overflow_iter) % \ + self.scale_window == 0: + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + def state_dict(self): + """Returns the state of the scaler as a :class:`dict`.""" + return dict( + cur_scale=self.cur_scale, + cur_iter=self.cur_iter, + mode=self.mode, + last_overflow_iter=self.last_overflow_iter, + scale_factor=self.scale_factor, + scale_window=self.scale_window) + + def load_state_dict(self, state_dict): + """Loads the loss_scaler state dict. + + Args: + state_dict (dict): scaler state. + """ + self.cur_scale = state_dict['cur_scale'] + self.cur_iter = state_dict['cur_iter'] + self.mode = state_dict['mode'] + self.last_overflow_iter = state_dict['last_overflow_iter'] + self.scale_factor = state_dict['scale_factor'] + self.scale_window = state_dict['scale_window'] + + @property + def loss_scale(self): + return self.cur_scale diff --git a/annotator/uniformer/mmcv/runner/hooks/__init__.py b/annotator/uniformer/mmcv/runner/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..915af28cefab14a14c1188ed861161080fd138a3 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .checkpoint import CheckpointHook +from .closure import ClosureHook +from .ema import EMAHook +from .evaluation import DistEvalHook, EvalHook +from .hook import HOOKS, Hook +from .iter_timer import IterTimerHook +from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook, + NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook, + TextLoggerHook, WandbLoggerHook) +from .lr_updater import LrUpdaterHook +from .memory import EmptyCacheHook +from .momentum_updater import MomentumUpdaterHook +from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, OptimizerHook) +from .profiler import ProfilerHook +from .sampler_seed import DistSamplerSeedHook +from .sync_buffer import SyncBuffersHook + +__all__ = [ + 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', + 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook', + 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', + 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', + 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook', + 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook', + 'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook', + 'GradientCumulativeFp16OptimizerHook' +] diff --git a/annotator/uniformer/mmcv/runner/hooks/checkpoint.py b/annotator/uniformer/mmcv/runner/hooks/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6af3fae43ac4b35532641a81eb13557edfc7dfba --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/checkpoint.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings + +from annotator.uniformer.mmcv.fileio import FileClient +from ..dist_utils import allreduce_params, master_only +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class CheckpointHook(Hook): + """Save checkpoints periodically. + + Args: + interval (int): The saving period. If ``by_epoch=True``, interval + indicates epochs, otherwise it indicates iterations. + Default: -1, which means "never". + by_epoch (bool): Saving checkpoints by epoch or by iteration. + Default: True. + save_optimizer (bool): Whether to save optimizer state_dict in the + checkpoint. It is usually used for resuming experiments. + Default: True. + out_dir (str, optional): The root directory to save checkpoints. If not + specified, ``runner.work_dir`` will be used by default. If + specified, the ``out_dir`` will be the concatenation of ``out_dir`` + and the last level directory of ``runner.work_dir``. + `Changed in version 1.3.16.` + max_keep_ckpts (int, optional): The maximum checkpoints to keep. + In some cases we want only the latest few checkpoints and would + like to delete old ones to save the disk space. + Default: -1, which means unlimited. + save_last (bool, optional): Whether to force the last checkpoint to be + saved regardless of interval. Default: True. + sync_buffer (bool, optional): Whether to synchronize buffers in + different gpus. Default: False. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + + .. warning:: + Before v1.3.16, the ``out_dir`` argument indicates the path where the + checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the + root directory and the final path to save checkpoint is the + concatenation of ``out_dir`` and the last level directory of + ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A" + and the value of ``runner.work_dir`` is "/path/of/B", then the final + path will be "/path/of/A/B". + """ + + def __init__(self, + interval=-1, + by_epoch=True, + save_optimizer=True, + out_dir=None, + max_keep_ckpts=-1, + save_last=True, + sync_buffer=False, + file_client_args=None, + **kwargs): + self.interval = interval + self.by_epoch = by_epoch + self.save_optimizer = save_optimizer + self.out_dir = out_dir + self.max_keep_ckpts = max_keep_ckpts + self.save_last = save_last + self.args = kwargs + self.sync_buffer = sync_buffer + self.file_client_args = file_client_args + + def before_run(self, runner): + if not self.out_dir: + self.out_dir = runner.work_dir + + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + + # if `self.out_dir` is not equal to `runner.work_dir`, it means that + # `self.out_dir` is set so the final `self.out_dir` is the + # concatenation of `self.out_dir` and the last level directory of + # `runner.work_dir` + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + + runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by ' + f'{self.file_client.name}.')) + + # disable the create_symlink option because some file backends do not + # allow to create a symlink + if 'create_symlink' in self.args: + if self.args[ + 'create_symlink'] and not self.file_client.allow_symlink: + self.args['create_symlink'] = False + warnings.warn( + ('create_symlink is set as True by the user but is changed' + 'to be False because creating symbolic link is not ' + f'allowed in {self.file_client.name}')) + else: + self.args['create_symlink'] = self.file_client.allow_symlink + + def after_train_epoch(self, runner): + if not self.by_epoch: + return + + # save checkpoint for following cases: + # 1. every ``self.interval`` epochs + # 2. reach the last epoch of training + if self.every_n_epochs( + runner, self.interval) or (self.save_last + and self.is_last_epoch(runner)): + runner.logger.info( + f'Saving checkpoint at {runner.epoch + 1} epochs') + if self.sync_buffer: + allreduce_params(runner.model.buffers()) + self._save_checkpoint(runner) + + @master_only + def _save_checkpoint(self, runner): + """Save the current checkpoint and delete unwanted checkpoint.""" + runner.save_checkpoint( + self.out_dir, save_optimizer=self.save_optimizer, **self.args) + if runner.meta is not None: + if self.by_epoch: + cur_ckpt_filename = self.args.get( + 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1) + else: + cur_ckpt_filename = self.args.get( + 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1) + runner.meta.setdefault('hook_msgs', dict()) + runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path( + self.out_dir, cur_ckpt_filename) + # remove other checkpoints + if self.max_keep_ckpts > 0: + if self.by_epoch: + name = 'epoch_{}.pth' + current_ckpt = runner.epoch + 1 + else: + name = 'iter_{}.pth' + current_ckpt = runner.iter + 1 + redundant_ckpts = range( + current_ckpt - self.max_keep_ckpts * self.interval, 0, + -self.interval) + filename_tmpl = self.args.get('filename_tmpl', name) + for _step in redundant_ckpts: + ckpt_path = self.file_client.join_path( + self.out_dir, filename_tmpl.format(_step)) + if self.file_client.isfile(ckpt_path): + self.file_client.remove(ckpt_path) + else: + break + + def after_train_iter(self, runner): + if self.by_epoch: + return + + # save checkpoint for following cases: + # 1. every ``self.interval`` iterations + # 2. reach the last iteration of training + if self.every_n_iters( + runner, self.interval) or (self.save_last + and self.is_last_iter(runner)): + runner.logger.info( + f'Saving checkpoint at {runner.iter + 1} iterations') + if self.sync_buffer: + allreduce_params(runner.model.buffers()) + self._save_checkpoint(runner) diff --git a/annotator/uniformer/mmcv/runner/hooks/closure.py b/annotator/uniformer/mmcv/runner/hooks/closure.py new file mode 100644 index 0000000000000000000000000000000000000000..b955f81f425be4ac3e6bb3f4aac653887989e872 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/closure.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class ClosureHook(Hook): + + def __init__(self, fn_name, fn): + assert hasattr(self, fn_name) + assert callable(fn) + setattr(self, fn_name, fn) diff --git a/annotator/uniformer/mmcv/runner/hooks/ema.py b/annotator/uniformer/mmcv/runner/hooks/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..15c7e68088f019802a59e7ae41cc1fe0c7f28f96 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/ema.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...parallel import is_module_wrapper +from ..hooks.hook import HOOKS, Hook + + +@HOOKS.register_module() +class EMAHook(Hook): + r"""Exponential Moving Average Hook. + + Use Exponential Moving Average on all parameters of model in training + process. All parameters have a ema backup, which update by the formula + as below. EMAHook takes priority over EvalHook and CheckpointSaverHook. + + .. math:: + + \text{Xema\_{t+1}} = (1 - \text{momentum}) \times + \text{Xema\_{t}} + \text{momentum} \times X_t + + Args: + momentum (float): The momentum used for updating ema parameter. + Defaults to 0.0002. + interval (int): Update ema parameter every interval iteration. + Defaults to 1. + warm_up (int): During first warm_up steps, we may use smaller momentum + to update ema parameters more slowly. Defaults to 100. + resume_from (str): The checkpoint path. Defaults to None. + """ + + def __init__(self, + momentum=0.0002, + interval=1, + warm_up=100, + resume_from=None): + assert isinstance(interval, int) and interval > 0 + self.warm_up = warm_up + self.interval = interval + assert momentum > 0 and momentum < 1 + self.momentum = momentum**interval + self.checkpoint = resume_from + + def before_run(self, runner): + """To resume model with it's ema parameters more friendly. + + Register ema parameter as ``named_buffer`` to model + """ + model = runner.model + if is_module_wrapper(model): + model = model.module + self.param_ema_buffer = {} + self.model_parameters = dict(model.named_parameters(recurse=True)) + for name, value in self.model_parameters.items(): + # "." is not allowed in module's buffer name + buffer_name = f"ema_{name.replace('.', '_')}" + self.param_ema_buffer[name] = buffer_name + model.register_buffer(buffer_name, value.data.clone()) + self.model_buffers = dict(model.named_buffers(recurse=True)) + if self.checkpoint is not None: + runner.resume(self.checkpoint) + + def after_train_iter(self, runner): + """Update ema parameter every self.interval iterations.""" + curr_step = runner.iter + # We warm up the momentum considering the instability at beginning + momentum = min(self.momentum, + (1 + curr_step) / (self.warm_up + curr_step)) + if curr_step % self.interval != 0: + return + for name, parameter in self.model_parameters.items(): + buffer_name = self.param_ema_buffer[name] + buffer_parameter = self.model_buffers[buffer_name] + buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data) + + def after_train_epoch(self, runner): + """We load parameter values from ema backup to model before the + EvalHook.""" + self._swap_ema_parameters() + + def before_train_epoch(self, runner): + """We recover model's parameter from ema backup after last epoch's + EvalHook.""" + self._swap_ema_parameters() + + def _swap_ema_parameters(self): + """Swap the parameter of model with parameter in ema_buffer.""" + for name, value in self.model_parameters.items(): + temp = value.data.clone() + ema_buffer = self.model_buffers[self.param_ema_buffer[name]] + value.data.copy_(ema_buffer.data) + ema_buffer.data.copy_(temp) diff --git a/annotator/uniformer/mmcv/runner/hooks/evaluation.py b/annotator/uniformer/mmcv/runner/hooks/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..4d00999ce5665c53bded8de9e084943eee2d230d --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/evaluation.py @@ -0,0 +1,509 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from math import inf + +import torch.distributed as dist +from torch.nn.modules.batchnorm import _BatchNorm +from torch.utils.data import DataLoader + +from annotator.uniformer.mmcv.fileio import FileClient +from annotator.uniformer.mmcv.utils import is_seq_of +from .hook import Hook +from .logger import LoggerHook + + +class EvalHook(Hook): + """Non-Distributed evaluation hook. + + This hook will regularly perform evaluation in a given interval when + performing in non-distributed environment. + + Args: + dataloader (DataLoader): A PyTorch dataloader, whose dataset has + implemented ``evaluate`` function. + start (int | None, optional): Evaluation starting epoch. It enables + evaluation before the training starts if ``start`` <= the resuming + epoch. If None, whether to evaluate is merely decided by + ``interval``. Default: None. + interval (int): Evaluation interval. Default: 1. + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: True. + save_best (str, optional): If a metric is specified, it would measure + the best checkpoint during evaluation. The information about best + checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep + best score value and best checkpoint path, which will be also + loaded when resume checkpoint. Options are the evaluation metrics + on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox + detection and instance segmentation. ``AR@100`` for proposal + recall. If ``save_best`` is ``auto``, the first key of the returned + ``OrderedDict`` result will be used. Default: None. + rule (str | None, optional): Comparison rule for best score. If set to + None, it will infer a reasonable rule. Keys such as 'acc', 'top' + .etc will be inferred by 'greater' rule. Keys contain 'loss' will + be inferred by 'less' rule. Options are 'greater', 'less', None. + Default: None. + test_fn (callable, optional): test a model with samples from a + dataloader, and return the test results. If ``None``, the default + test function ``mmcv.engine.single_gpu_test`` will be used. + (default: ``None``) + greater_keys (List[str] | None, optional): Metric keys that will be + inferred by 'greater' comparison rule. If ``None``, + _default_greater_keys will be used. (default: ``None``) + less_keys (List[str] | None, optional): Metric keys that will be + inferred by 'less' comparison rule. If ``None``, _default_less_keys + will be used. (default: ``None``) + out_dir (str, optional): The root directory to save checkpoints. If not + specified, `runner.work_dir` will be used by default. If specified, + the `out_dir` will be the concatenation of `out_dir` and the last + level directory of `runner.work_dir`. + `New in version 1.3.16.` + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. + `New in version 1.3.16.` + **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. + + Notes: + If new arguments are added for EvalHook, tools/test.py, + tools/eval_metric.py may be affected. + """ + + # Since the key for determine greater or less is related to the downstream + # tasks, downstream repos may need to overwrite the following inner + # variable accordingly. + + rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y} + init_value_map = {'greater': -inf, 'less': inf} + _default_greater_keys = [ + 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU', + 'mAcc', 'aAcc' + ] + _default_less_keys = ['loss'] + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + test_fn=None, + greater_keys=None, + less_keys=None, + out_dir=None, + file_client_args=None, + **eval_kwargs): + if not isinstance(dataloader, DataLoader): + raise TypeError(f'dataloader must be a pytorch DataLoader, ' + f'but got {type(dataloader)}') + + if interval <= 0: + raise ValueError(f'interval must be a positive number, ' + f'but got {interval}') + + assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean' + + if start is not None and start < 0: + raise ValueError(f'The evaluation start epoch {start} is smaller ' + f'than 0') + + self.dataloader = dataloader + self.interval = interval + self.start = start + self.by_epoch = by_epoch + + assert isinstance(save_best, str) or save_best is None, \ + '""save_best"" should be a str or None ' \ + f'rather than {type(save_best)}' + self.save_best = save_best + self.eval_kwargs = eval_kwargs + self.initial_flag = True + + if test_fn is None: + from annotator.uniformer.mmcv.engine import single_gpu_test + self.test_fn = single_gpu_test + else: + self.test_fn = test_fn + + if greater_keys is None: + self.greater_keys = self._default_greater_keys + else: + if not isinstance(greater_keys, (list, tuple)): + greater_keys = (greater_keys, ) + assert is_seq_of(greater_keys, str) + self.greater_keys = greater_keys + + if less_keys is None: + self.less_keys = self._default_less_keys + else: + if not isinstance(less_keys, (list, tuple)): + less_keys = (less_keys, ) + assert is_seq_of(less_keys, str) + self.less_keys = less_keys + + if self.save_best is not None: + self.best_ckpt_path = None + self._init_rule(rule, self.save_best) + + self.out_dir = out_dir + self.file_client_args = file_client_args + + def _init_rule(self, rule, key_indicator): + """Initialize rule, key_indicator, comparison_func, and best score. + + Here is the rule to determine which rule is used for key indicator + when the rule is not specific (note that the key indicator matching + is case-insensitive): + 1. If the key indicator is in ``self.greater_keys``, the rule will be + specified as 'greater'. + 2. Or if the key indicator is in ``self.less_keys``, the rule will be + specified as 'less'. + 3. Or if the key indicator is equal to the substring in any one item + in ``self.greater_keys``, the rule will be specified as 'greater'. + 4. Or if the key indicator is equal to the substring in any one item + in ``self.less_keys``, the rule will be specified as 'less'. + + Args: + rule (str | None): Comparison rule for best score. + key_indicator (str | None): Key indicator to determine the + comparison rule. + """ + if rule not in self.rule_map and rule is not None: + raise KeyError(f'rule must be greater, less or None, ' + f'but got {rule}.') + + if rule is None: + if key_indicator != 'auto': + # `_lc` here means we use the lower case of keys for + # case-insensitive matching + key_indicator_lc = key_indicator.lower() + greater_keys = [key.lower() for key in self.greater_keys] + less_keys = [key.lower() for key in self.less_keys] + + if key_indicator_lc in greater_keys: + rule = 'greater' + elif key_indicator_lc in less_keys: + rule = 'less' + elif any(key in key_indicator_lc for key in greater_keys): + rule = 'greater' + elif any(key in key_indicator_lc for key in less_keys): + rule = 'less' + else: + raise ValueError(f'Cannot infer the rule for key ' + f'{key_indicator}, thus a specific rule ' + f'must be specified.') + self.rule = rule + self.key_indicator = key_indicator + if self.rule is not None: + self.compare_func = self.rule_map[self.rule] + + def before_run(self, runner): + if not self.out_dir: + self.out_dir = runner.work_dir + + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + + # if `self.out_dir` is not equal to `runner.work_dir`, it means that + # `self.out_dir` is set so the final `self.out_dir` is the + # concatenation of `self.out_dir` and the last level directory of + # `runner.work_dir` + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + (f'The best checkpoint will be saved to {self.out_dir} by ' + f'{self.file_client.name}')) + + if self.save_best is not None: + if runner.meta is None: + warnings.warn('runner.meta is None. Creating an empty one.') + runner.meta = dict() + runner.meta.setdefault('hook_msgs', dict()) + self.best_ckpt_path = runner.meta['hook_msgs'].get( + 'best_ckpt', None) + + def before_train_iter(self, runner): + """Evaluate the model only at the start of training by iteration.""" + if self.by_epoch or not self.initial_flag: + return + if self.start is not None and runner.iter >= self.start: + self.after_train_iter(runner) + self.initial_flag = False + + def before_train_epoch(self, runner): + """Evaluate the model only at the start of training by epoch.""" + if not (self.by_epoch and self.initial_flag): + return + if self.start is not None and runner.epoch >= self.start: + self.after_train_epoch(runner) + self.initial_flag = False + + def after_train_iter(self, runner): + """Called after every training iter to evaluate the results.""" + if not self.by_epoch and self._should_evaluate(runner): + # Because the priority of EvalHook is higher than LoggerHook, the + # training log and the evaluating log are mixed. Therefore, + # we need to dump the training log and clear it before evaluating + # log is generated. In addition, this problem will only appear in + # `IterBasedRunner` whose `self.by_epoch` is False, because + # `EpochBasedRunner` whose `self.by_epoch` is True calls + # `_do_evaluate` in `after_train_epoch` stage, and at this stage + # the training log has been printed, so it will not cause any + # problem. more details at + # https://github.com/open-mmlab/mmsegmentation/issues/694 + for hook in runner._hooks: + if isinstance(hook, LoggerHook): + hook.after_train_iter(runner) + runner.log_buffer.clear() + + self._do_evaluate(runner) + + def after_train_epoch(self, runner): + """Called after every training epoch to evaluate the results.""" + if self.by_epoch and self._should_evaluate(runner): + self._do_evaluate(runner) + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + results = self.test_fn(runner.model, self.dataloader) + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + # the key_score may be `None` so it needs to skip the action to save + # the best checkpoint + if self.save_best and key_score: + self._save_ckpt(runner, key_score) + + def _should_evaluate(self, runner): + """Judge whether to perform evaluation. + + Here is the rule to judge whether to perform evaluation: + 1. It will not perform evaluation during the epoch/iteration interval, + which is determined by ``self.interval``. + 2. It will not perform evaluation if the start time is larger than + current time. + 3. It will not perform evaluation when current time is larger than + the start time but during epoch/iteration interval. + + Returns: + bool: The flag indicating whether to perform evaluation. + """ + if self.by_epoch: + current = runner.epoch + check_time = self.every_n_epochs + else: + current = runner.iter + check_time = self.every_n_iters + + if self.start is None: + if not check_time(runner, self.interval): + # No evaluation during the interval. + return False + elif (current + 1) < self.start: + # No evaluation if start is larger than the current time. + return False + else: + # Evaluation only at epochs/iters 3, 5, 7... + # if start==3 and interval==2 + if (current + 1 - self.start) % self.interval: + return False + return True + + def _save_ckpt(self, runner, key_score): + """Save the best checkpoint. + + It will compare the score according to the compare function, write + related information (best score, best checkpoint path) and save the + best checkpoint into ``work_dir``. + """ + if self.by_epoch: + current = f'epoch_{runner.epoch + 1}' + cur_type, cur_time = 'epoch', runner.epoch + 1 + else: + current = f'iter_{runner.iter + 1}' + cur_type, cur_time = 'iter', runner.iter + 1 + + best_score = runner.meta['hook_msgs'].get( + 'best_score', self.init_value_map[self.rule]) + if self.compare_func(key_score, best_score): + best_score = key_score + runner.meta['hook_msgs']['best_score'] = best_score + + if self.best_ckpt_path and self.file_client.isfile( + self.best_ckpt_path): + self.file_client.remove(self.best_ckpt_path) + runner.logger.info( + (f'The previous best checkpoint {self.best_ckpt_path} was ' + 'removed')) + + best_ckpt_name = f'best_{self.key_indicator}_{current}.pth' + self.best_ckpt_path = self.file_client.join_path( + self.out_dir, best_ckpt_name) + runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path + + runner.save_checkpoint( + self.out_dir, best_ckpt_name, create_symlink=False) + runner.logger.info( + f'Now best checkpoint is saved as {best_ckpt_name}.') + runner.logger.info( + f'Best {self.key_indicator} is {best_score:0.4f} ' + f'at {cur_time} {cur_type}.') + + def evaluate(self, runner, results): + """Evaluate the results. + + Args: + runner (:obj:`mmcv.Runner`): The underlined training runner. + results (list): Output results. + """ + eval_res = self.dataloader.dataset.evaluate( + results, logger=runner.logger, **self.eval_kwargs) + + for name, val in eval_res.items(): + runner.log_buffer.output[name] = val + runner.log_buffer.ready = True + + if self.save_best is not None: + # If the performance of model is pool, the `eval_res` may be an + # empty dict and it will raise exception when `self.save_best` is + # not None. More details at + # https://github.com/open-mmlab/mmdetection/issues/6265. + if not eval_res: + warnings.warn( + 'Since `eval_res` is an empty dict, the behavior to save ' + 'the best checkpoint will be skipped in this evaluation.') + return None + + if self.key_indicator == 'auto': + # infer from eval_results + self._init_rule(self.rule, list(eval_res.keys())[0]) + return eval_res[self.key_indicator] + + return None + + +class DistEvalHook(EvalHook): + """Distributed evaluation hook. + + This hook will regularly perform evaluation in a given interval when + performing in distributed environment. + + Args: + dataloader (DataLoader): A PyTorch dataloader, whose dataset has + implemented ``evaluate`` function. + start (int | None, optional): Evaluation starting epoch. It enables + evaluation before the training starts if ``start`` <= the resuming + epoch. If None, whether to evaluate is merely decided by + ``interval``. Default: None. + interval (int): Evaluation interval. Default: 1. + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + default: True. + save_best (str, optional): If a metric is specified, it would measure + the best checkpoint during evaluation. The information about best + checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep + best score value and best checkpoint path, which will be also + loaded when resume checkpoint. Options are the evaluation metrics + on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox + detection and instance segmentation. ``AR@100`` for proposal + recall. If ``save_best`` is ``auto``, the first key of the returned + ``OrderedDict`` result will be used. Default: None. + rule (str | None, optional): Comparison rule for best score. If set to + None, it will infer a reasonable rule. Keys such as 'acc', 'top' + .etc will be inferred by 'greater' rule. Keys contain 'loss' will + be inferred by 'less' rule. Options are 'greater', 'less', None. + Default: None. + test_fn (callable, optional): test a model with samples from a + dataloader in a multi-gpu manner, and return the test results. If + ``None``, the default test function ``mmcv.engine.multi_gpu_test`` + will be used. (default: ``None``) + tmpdir (str | None): Temporary directory to save the results of all + processes. Default: None. + gpu_collect (bool): Whether to use gpu or cpu to collect results. + Default: False. + broadcast_bn_buffer (bool): Whether to broadcast the + buffer(running_mean and running_var) of rank 0 to other rank + before evaluation. Default: True. + out_dir (str, optional): The root directory to save checkpoints. If not + specified, `runner.work_dir` will be used by default. If specified, + the `out_dir` will be the concatenation of `out_dir` and the last + level directory of `runner.work_dir`. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. + **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. + """ + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + test_fn=None, + greater_keys=None, + less_keys=None, + broadcast_bn_buffer=True, + tmpdir=None, + gpu_collect=False, + out_dir=None, + file_client_args=None, + **eval_kwargs): + + if test_fn is None: + from annotator.uniformer.mmcv.engine import multi_gpu_test + test_fn = multi_gpu_test + + super().__init__( + dataloader, + start=start, + interval=interval, + by_epoch=by_epoch, + save_best=save_best, + rule=rule, + test_fn=test_fn, + greater_keys=greater_keys, + less_keys=less_keys, + out_dir=out_dir, + file_client_args=file_client_args, + **eval_kwargs) + + self.broadcast_bn_buffer = broadcast_bn_buffer + self.tmpdir = tmpdir + self.gpu_collect = gpu_collect + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + # Synchronization of BatchNorm's buffer (running_mean + # and running_var) is not supported in the DDP of pytorch, + # which may cause the inconsistent performance of models in + # different ranks, so we broadcast BatchNorm's buffers + # of rank 0 to other ranks to avoid this. + if self.broadcast_bn_buffer: + model = runner.model + for name, module in model.named_modules(): + if isinstance(module, + _BatchNorm) and module.track_running_stats: + dist.broadcast(module.running_var, 0) + dist.broadcast(module.running_mean, 0) + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, '.eval_hook') + + results = self.test_fn( + runner.model, + self.dataloader, + tmpdir=tmpdir, + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + # the key_score may be `None` so it needs to skip the action to + # save the best checkpoint + if self.save_best and key_score: + self._save_ckpt(runner, key_score) diff --git a/annotator/uniformer/mmcv/runner/hooks/hook.py b/annotator/uniformer/mmcv/runner/hooks/hook.py new file mode 100644 index 0000000000000000000000000000000000000000..b8855c107727ecf85b917c890fc8b7f6359238a4 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/hook.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from annotator.uniformer.mmcv.utils import Registry, is_method_overridden + +HOOKS = Registry('hook') + + +class Hook: + stages = ('before_run', 'before_train_epoch', 'before_train_iter', + 'after_train_iter', 'after_train_epoch', 'before_val_epoch', + 'before_val_iter', 'after_val_iter', 'after_val_epoch', + 'after_run') + + def before_run(self, runner): + pass + + def after_run(self, runner): + pass + + def before_epoch(self, runner): + pass + + def after_epoch(self, runner): + pass + + def before_iter(self, runner): + pass + + def after_iter(self, runner): + pass + + def before_train_epoch(self, runner): + self.before_epoch(runner) + + def before_val_epoch(self, runner): + self.before_epoch(runner) + + def after_train_epoch(self, runner): + self.after_epoch(runner) + + def after_val_epoch(self, runner): + self.after_epoch(runner) + + def before_train_iter(self, runner): + self.before_iter(runner) + + def before_val_iter(self, runner): + self.before_iter(runner) + + def after_train_iter(self, runner): + self.after_iter(runner) + + def after_val_iter(self, runner): + self.after_iter(runner) + + def every_n_epochs(self, runner, n): + return (runner.epoch + 1) % n == 0 if n > 0 else False + + def every_n_inner_iters(self, runner, n): + return (runner.inner_iter + 1) % n == 0 if n > 0 else False + + def every_n_iters(self, runner, n): + return (runner.iter + 1) % n == 0 if n > 0 else False + + def end_of_epoch(self, runner): + return runner.inner_iter + 1 == len(runner.data_loader) + + def is_last_epoch(self, runner): + return runner.epoch + 1 == runner._max_epochs + + def is_last_iter(self, runner): + return runner.iter + 1 == runner._max_iters + + def get_triggered_stages(self): + trigger_stages = set() + for stage in Hook.stages: + if is_method_overridden(stage, Hook, self): + trigger_stages.add(stage) + + # some methods will be triggered in multi stages + # use this dict to map method to stages. + method_stages_map = { + 'before_epoch': ['before_train_epoch', 'before_val_epoch'], + 'after_epoch': ['after_train_epoch', 'after_val_epoch'], + 'before_iter': ['before_train_iter', 'before_val_iter'], + 'after_iter': ['after_train_iter', 'after_val_iter'], + } + + for method, map_stages in method_stages_map.items(): + if is_method_overridden(method, Hook, self): + trigger_stages.update(map_stages) + + return [stage for stage in Hook.stages if stage in trigger_stages] diff --git a/annotator/uniformer/mmcv/runner/hooks/iter_timer.py b/annotator/uniformer/mmcv/runner/hooks/iter_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd5002fe85ffc6992155ac01003878064a1d9be --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/iter_timer.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time + +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class IterTimerHook(Hook): + + def before_epoch(self, runner): + self.t = time.time() + + def before_iter(self, runner): + runner.log_buffer.update({'data_time': time.time() - self.t}) + + def after_iter(self, runner): + runner.log_buffer.update({'time': time.time() - self.t}) + self.t = time.time() diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py b/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b6b345640a895368ac8a647afef6f24333d90e --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import LoggerHook +from .dvclive import DvcliveLoggerHook +from .mlflow import MlflowLoggerHook +from .neptune import NeptuneLoggerHook +from .pavi import PaviLoggerHook +from .tensorboard import TensorboardLoggerHook +from .text import TextLoggerHook +from .wandb import WandbLoggerHook + +__all__ = [ + 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', + 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook', + 'NeptuneLoggerHook', 'DvcliveLoggerHook' +] diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/base.py b/annotator/uniformer/mmcv/runner/hooks/logger/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f845256729458ced821762a1b8ef881e17ff9955 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/base.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numbers +from abc import ABCMeta, abstractmethod + +import numpy as np +import torch + +from ..hook import Hook + + +class LoggerHook(Hook): + """Base class for logger hooks. + + Args: + interval (int): Logging interval (every k iterations). + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging. + by_epoch (bool): Whether EpochBasedRunner is used. + """ + + __metaclass__ = ABCMeta + + def __init__(self, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True): + self.interval = interval + self.ignore_last = ignore_last + self.reset_flag = reset_flag + self.by_epoch = by_epoch + + @abstractmethod + def log(self, runner): + pass + + @staticmethod + def is_scalar(val, include_np=True, include_torch=True): + """Tell the input variable is a scalar or not. + + Args: + val: Input variable. + include_np (bool): Whether include 0-d np.ndarray as a scalar. + include_torch (bool): Whether include 0-d torch.Tensor as a scalar. + + Returns: + bool: True or False. + """ + if isinstance(val, numbers.Number): + return True + elif include_np and isinstance(val, np.ndarray) and val.ndim == 0: + return True + elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1: + return True + else: + return False + + def get_mode(self, runner): + if runner.mode == 'train': + if 'time' in runner.log_buffer.output: + mode = 'train' + else: + mode = 'val' + elif runner.mode == 'val': + mode = 'val' + else: + raise ValueError(f"runner mode should be 'train' or 'val', " + f'but got {runner.mode}') + return mode + + def get_epoch(self, runner): + if runner.mode == 'train': + epoch = runner.epoch + 1 + elif runner.mode == 'val': + # normal val mode + # runner.epoch += 1 has been done before val workflow + epoch = runner.epoch + else: + raise ValueError(f"runner mode should be 'train' or 'val', " + f'but got {runner.mode}') + return epoch + + def get_iter(self, runner, inner_iter=False): + """Get the current training iteration step.""" + if self.by_epoch and inner_iter: + current_iter = runner.inner_iter + 1 + else: + current_iter = runner.iter + 1 + return current_iter + + def get_lr_tags(self, runner): + tags = {} + lrs = runner.current_lr() + if isinstance(lrs, dict): + for name, value in lrs.items(): + tags[f'learning_rate/{name}'] = value[0] + else: + tags['learning_rate'] = lrs[0] + return tags + + def get_momentum_tags(self, runner): + tags = {} + momentums = runner.current_momentum() + if isinstance(momentums, dict): + for name, value in momentums.items(): + tags[f'momentum/{name}'] = value[0] + else: + tags['momentum'] = momentums[0] + return tags + + def get_loggable_tags(self, + runner, + allow_scalar=True, + allow_text=False, + add_mode=True, + tags_to_skip=('time', 'data_time')): + tags = {} + for var, val in runner.log_buffer.output.items(): + if var in tags_to_skip: + continue + if self.is_scalar(val) and not allow_scalar: + continue + if isinstance(val, str) and not allow_text: + continue + if add_mode: + var = f'{self.get_mode(runner)}/{var}' + tags[var] = val + tags.update(self.get_lr_tags(runner)) + tags.update(self.get_momentum_tags(runner)) + return tags + + def before_run(self, runner): + for hook in runner.hooks[::-1]: + if isinstance(hook, LoggerHook): + hook.reset_flag = True + break + + def before_epoch(self, runner): + runner.log_buffer.clear() # clear logs of last epoch + + def after_train_iter(self, runner): + if self.by_epoch and self.every_n_inner_iters(runner, self.interval): + runner.log_buffer.average(self.interval) + elif not self.by_epoch and self.every_n_iters(runner, self.interval): + runner.log_buffer.average(self.interval) + elif self.end_of_epoch(runner) and not self.ignore_last: + # not precise but more stable + runner.log_buffer.average(self.interval) + + if runner.log_buffer.ready: + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() + + def after_train_epoch(self, runner): + if runner.log_buffer.ready: + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() + + def after_val_epoch(self, runner): + runner.log_buffer.average() + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py b/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py new file mode 100644 index 0000000000000000000000000000000000000000..687cdc58c0336c92b1e4f9a410ba67ebaab2bc7a --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class DvcliveLoggerHook(LoggerHook): + """Class to log metrics with dvclive. + + It requires `dvclive`_ to be installed. + + Args: + path (str): Directory where dvclive will write TSV log files. + interval (int): Logging interval (every k iterations). + Default 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: True. + by_epoch (bool): Whether EpochBasedRunner is used. + Default: True. + + .. _dvclive: + https://dvc.org/doc/dvclive + """ + + def __init__(self, + path, + interval=10, + ignore_last=True, + reset_flag=True, + by_epoch=True): + + super(DvcliveLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.path = path + self.import_dvclive() + + def import_dvclive(self): + try: + import dvclive + except ImportError: + raise ImportError( + 'Please run "pip install dvclive" to install dvclive') + self.dvclive = dvclive + + @master_only + def before_run(self, runner): + self.dvclive.init(self.path) + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + for k, v in tags.items(): + self.dvclive.log(k, v, step=self.get_iter(runner)) diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py b/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a72592be47b534ce22573775fd5a7e8e86d72d --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class MlflowLoggerHook(LoggerHook): + + def __init__(self, + exp_name=None, + tags=None, + log_model=True, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True): + """Class to log metrics and (optionally) a trained model to MLflow. + + It requires `MLflow`_ to be installed. + + Args: + exp_name (str, optional): Name of the experiment to be used. + Default None. + If not None, set the active experiment. + If experiment does not exist, an experiment with provided name + will be created. + tags (dict of str: str, optional): Tags for the current run. + Default None. + If not None, set tags for the current run. + log_model (bool, optional): Whether to log an MLflow artifact. + Default True. + If True, log runner.model as an MLflow artifact + for the current run. + interval (int): Logging interval (every k iterations). + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging + by_epoch (bool): Whether EpochBasedRunner is used. + + .. _MLflow: + https://www.mlflow.org/docs/latest/index.html + """ + super(MlflowLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.import_mlflow() + self.exp_name = exp_name + self.tags = tags + self.log_model = log_model + + def import_mlflow(self): + try: + import mlflow + import mlflow.pytorch as mlflow_pytorch + except ImportError: + raise ImportError( + 'Please run "pip install mlflow" to install mlflow') + self.mlflow = mlflow + self.mlflow_pytorch = mlflow_pytorch + + @master_only + def before_run(self, runner): + super(MlflowLoggerHook, self).before_run(runner) + if self.exp_name is not None: + self.mlflow.set_experiment(self.exp_name) + if self.tags is not None: + self.mlflow.set_tags(self.tags) + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + self.mlflow.log_metrics(tags, step=self.get_iter(runner)) + + @master_only + def after_run(self, runner): + if self.log_model: + self.mlflow_pytorch.log_model(runner.model, 'models') diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py b/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py new file mode 100644 index 0000000000000000000000000000000000000000..7a38772b0c93a8608f32c6357b8616e77c139dc9 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class NeptuneLoggerHook(LoggerHook): + """Class to log metrics to NeptuneAI. + + It requires `neptune-client` to be installed. + + Args: + init_kwargs (dict): a dict contains the initialization keys as below: + - project (str): Name of a project in a form of + namespace/project_name. If None, the value of + NEPTUNE_PROJECT environment variable will be taken. + - api_token (str): User’s API token. + If None, the value of NEPTUNE_API_TOKEN environment + variable will be taken. Note: It is strongly recommended + to use NEPTUNE_API_TOKEN environment variable rather than + placing your API token in plain text in your source code. + - name (str, optional, default is 'Untitled'): Editable name of + the run. Name is displayed in the run's Details and in + Runs table as a column. + Check https://docs.neptune.ai/api-reference/neptune#init for + more init arguments. + interval (int): Logging interval (every k iterations). + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging + by_epoch (bool): Whether EpochBasedRunner is used. + + .. _NeptuneAI: + https://docs.neptune.ai/you-should-know/logging-metadata + """ + + def __init__(self, + init_kwargs=None, + interval=10, + ignore_last=True, + reset_flag=True, + with_step=True, + by_epoch=True): + + super(NeptuneLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.import_neptune() + self.init_kwargs = init_kwargs + self.with_step = with_step + + def import_neptune(self): + try: + import neptune.new as neptune + except ImportError: + raise ImportError( + 'Please run "pip install neptune-client" to install neptune') + self.neptune = neptune + self.run = None + + @master_only + def before_run(self, runner): + if self.init_kwargs: + self.run = self.neptune.init(**self.init_kwargs) + else: + self.run = self.neptune.init() + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + for tag_name, tag_value in tags.items(): + if self.with_step: + self.run[tag_name].log( + tag_value, step=self.get_iter(runner)) + else: + tags['global_step'] = self.get_iter(runner) + self.run[tag_name].log(tags) + + @master_only + def after_run(self, runner): + self.run.stop() diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py b/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py new file mode 100644 index 0000000000000000000000000000000000000000..1dcf146d8163aff1363e9764999b0a74d674a595 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os +import os.path as osp + +import torch +import yaml + +import annotator.uniformer.mmcv as mmcv +from ....parallel.utils import is_module_wrapper +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class PaviLoggerHook(LoggerHook): + + def __init__(self, + init_kwargs=None, + add_graph=False, + add_last_ckpt=False, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True, + img_key='img_info'): + super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag, + by_epoch) + self.init_kwargs = init_kwargs + self.add_graph = add_graph + self.add_last_ckpt = add_last_ckpt + self.img_key = img_key + + @master_only + def before_run(self, runner): + super(PaviLoggerHook, self).before_run(runner) + try: + from pavi import SummaryWriter + except ImportError: + raise ImportError('Please run "pip install pavi" to install pavi.') + + self.run_name = runner.work_dir.split('/')[-1] + + if not self.init_kwargs: + self.init_kwargs = dict() + self.init_kwargs['name'] = self.run_name + self.init_kwargs['model'] = runner._model_name + if runner.meta is not None: + if 'config_dict' in runner.meta: + config_dict = runner.meta['config_dict'] + assert isinstance( + config_dict, + dict), ('meta["config_dict"] has to be of a dict, ' + f'but got {type(config_dict)}') + elif 'config_file' in runner.meta: + config_file = runner.meta['config_file'] + config_dict = dict(mmcv.Config.fromfile(config_file)) + else: + config_dict = None + if config_dict is not None: + # 'max_.*iter' is parsed in pavi sdk as the maximum iterations + # to properly set up the progress bar. + config_dict = config_dict.copy() + config_dict.setdefault('max_iter', runner.max_iters) + # non-serializable values are first converted in + # mmcv.dump to json + config_dict = json.loads( + mmcv.dump(config_dict, file_format='json')) + session_text = yaml.dump(config_dict) + self.init_kwargs['session_text'] = session_text + self.writer = SummaryWriter(**self.init_kwargs) + + def get_step(self, runner): + """Get the total training step/epoch.""" + if self.get_mode(runner) == 'val' and self.by_epoch: + return self.get_epoch(runner) + else: + return self.get_iter(runner) + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner, add_mode=False) + if tags: + self.writer.add_scalars( + self.get_mode(runner), tags, self.get_step(runner)) + + @master_only + def after_run(self, runner): + if self.add_last_ckpt: + ckpt_path = osp.join(runner.work_dir, 'latest.pth') + if osp.islink(ckpt_path): + ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path)) + + if osp.isfile(ckpt_path): + # runner.epoch += 1 has been done before `after_run`. + iteration = runner.epoch if self.by_epoch else runner.iter + return self.writer.add_snapshot_file( + tag=self.run_name, + snapshot_file_path=ckpt_path, + iteration=iteration) + + # flush the buffer and send a task ending signal to Pavi + self.writer.close() + + @master_only + def before_epoch(self, runner): + if runner.epoch == 0 and self.add_graph: + if is_module_wrapper(runner.model): + _model = runner.model.module + else: + _model = runner.model + device = next(_model.parameters()).device + data = next(iter(runner.data_loader)) + image = data[self.img_key][0:1].to(device) + with torch.no_grad(): + self.writer.add_graph(_model, image) diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py b/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5011dc08def6c09eef86d3ce5b124c9fc5372 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class TensorboardLoggerHook(LoggerHook): + + def __init__(self, + log_dir=None, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True): + super(TensorboardLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.log_dir = log_dir + + @master_only + def before_run(self, runner): + super(TensorboardLoggerHook, self).before_run(runner) + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.1')): + try: + from tensorboardX import SummaryWriter + except ImportError: + raise ImportError('Please install tensorboardX to use ' + 'TensorboardLoggerHook.') + else: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + raise ImportError( + 'Please run "pip install future tensorboard" to install ' + 'the dependencies to use torch.utils.tensorboard ' + '(applicable to PyTorch 1.1 or higher)') + + if self.log_dir is None: + self.log_dir = osp.join(runner.work_dir, 'tf_logs') + self.writer = SummaryWriter(self.log_dir) + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner, allow_text=True) + for tag, val in tags.items(): + if isinstance(val, str): + self.writer.add_text(tag, val, self.get_iter(runner)) + else: + self.writer.add_scalar(tag, val, self.get_iter(runner)) + + @master_only + def after_run(self, runner): + self.writer.close() diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/text.py b/annotator/uniformer/mmcv/runner/hooks/logger/text.py new file mode 100644 index 0000000000000000000000000000000000000000..87b1a3eca9595a130121526f8b4c29915387ab35 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/text.py @@ -0,0 +1,256 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import os +import os.path as osp +from collections import OrderedDict + +import torch +import torch.distributed as dist + +import annotator.uniformer.mmcv as mmcv +from annotator.uniformer.mmcv.fileio.file_client import FileClient +from annotator.uniformer.mmcv.utils import is_tuple_of, scandir +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class TextLoggerHook(LoggerHook): + """Logger hook in text. + + In this logger hook, the information will be printed on terminal and + saved in json file. + + Args: + by_epoch (bool, optional): Whether EpochBasedRunner is used. + Default: True. + interval (int, optional): Logging interval (every k iterations). + Default: 10. + ignore_last (bool, optional): Ignore the log of last iterations in each + epoch if less than :attr:`interval`. Default: True. + reset_flag (bool, optional): Whether to clear the output buffer after + logging. Default: False. + interval_exp_name (int, optional): Logging interval for experiment + name. This feature is to help users conveniently get the experiment + information from screen or log file. Default: 1000. + out_dir (str, optional): Logs are saved in ``runner.work_dir`` default. + If ``out_dir`` is specified, logs will be copied to a new directory + which is the concatenation of ``out_dir`` and the last level + directory of ``runner.work_dir``. Default: None. + `New in version 1.3.16.` + out_suffix (str or tuple[str], optional): Those filenames ending with + ``out_suffix`` will be copied to ``out_dir``. + Default: ('.log.json', '.log', '.py'). + `New in version 1.3.16.` + keep_local (bool, optional): Whether to keep local log when + :attr:`out_dir` is specified. If False, the local log will be + removed. Default: True. + `New in version 1.3.16.` + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + """ + + def __init__(self, + by_epoch=True, + interval=10, + ignore_last=True, + reset_flag=False, + interval_exp_name=1000, + out_dir=None, + out_suffix=('.log.json', '.log', '.py'), + keep_local=True, + file_client_args=None): + super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag, + by_epoch) + self.by_epoch = by_epoch + self.time_sec_tot = 0 + self.interval_exp_name = interval_exp_name + + if out_dir is None and file_client_args is not None: + raise ValueError( + 'file_client_args should be "None" when `out_dir` is not' + 'specified.') + self.out_dir = out_dir + + if not (out_dir is None or isinstance(out_dir, str) + or is_tuple_of(out_dir, str)): + raise TypeError('out_dir should be "None" or string or tuple of ' + 'string, but got {out_dir}') + self.out_suffix = out_suffix + + self.keep_local = keep_local + self.file_client_args = file_client_args + if self.out_dir is not None: + self.file_client = FileClient.infer_client(file_client_args, + self.out_dir) + + def before_run(self, runner): + super(TextLoggerHook, self).before_run(runner) + + if self.out_dir is not None: + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + # The final `self.out_dir` is the concatenation of `self.out_dir` + # and the last level directory of `runner.work_dir` + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + (f'Text logs will be saved to {self.out_dir} by ' + f'{self.file_client.name} after the training process.')) + + self.start_iter = runner.iter + self.json_log_path = osp.join(runner.work_dir, + f'{runner.timestamp}.log.json') + if runner.meta is not None: + self._dump_log(runner.meta, runner) + + def _get_max_memory(self, runner): + device = getattr(runner.model, 'output_device', None) + mem = torch.cuda.max_memory_allocated(device=device) + mem_mb = torch.tensor([mem / (1024 * 1024)], + dtype=torch.int, + device=device) + if runner.world_size > 1: + dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) + return mem_mb.item() + + def _log_info(self, log_dict, runner): + # print exp name for users to distinguish experiments + # at every ``interval_exp_name`` iterations and the end of each epoch + if runner.meta is not None and 'exp_name' in runner.meta: + if (self.every_n_iters(runner, self.interval_exp_name)) or ( + self.by_epoch and self.end_of_epoch(runner)): + exp_info = f'Exp name: {runner.meta["exp_name"]}' + runner.logger.info(exp_info) + + if log_dict['mode'] == 'train': + if isinstance(log_dict['lr'], dict): + lr_str = [] + for k, val in log_dict['lr'].items(): + lr_str.append(f'lr_{k}: {val:.3e}') + lr_str = ' '.join(lr_str) + else: + lr_str = f'lr: {log_dict["lr"]:.3e}' + + # by epoch: Epoch [4][100/1000] + # by iter: Iter [100/100000] + if self.by_epoch: + log_str = f'Epoch [{log_dict["epoch"]}]' \ + f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t' + else: + log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t' + log_str += f'{lr_str}, ' + + if 'time' in log_dict.keys(): + self.time_sec_tot += (log_dict['time'] * self.interval) + time_sec_avg = self.time_sec_tot / ( + runner.iter - self.start_iter + 1) + eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + log_str += f'eta: {eta_str}, ' + log_str += f'time: {log_dict["time"]:.3f}, ' \ + f'data_time: {log_dict["data_time"]:.3f}, ' + # statistic memory + if torch.cuda.is_available(): + log_str += f'memory: {log_dict["memory"]}, ' + else: + # val/test time + # here 1000 is the length of the val dataloader + # by epoch: Epoch[val] [4][1000] + # by iter: Iter[val] [1000] + if self.by_epoch: + log_str = f'Epoch({log_dict["mode"]}) ' \ + f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t' + else: + log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t' + + log_items = [] + for name, val in log_dict.items(): + # TODO: resolve this hack + # these items have been in log_str + if name in [ + 'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time', + 'memory', 'epoch' + ]: + continue + if isinstance(val, float): + val = f'{val:.4f}' + log_items.append(f'{name}: {val}') + log_str += ', '.join(log_items) + + runner.logger.info(log_str) + + def _dump_log(self, log_dict, runner): + # dump log in json format + json_log = OrderedDict() + for k, v in log_dict.items(): + json_log[k] = self._round_float(v) + # only append log at last line + if runner.rank == 0: + with open(self.json_log_path, 'a+') as f: + mmcv.dump(json_log, f, file_format='json') + f.write('\n') + + def _round_float(self, items): + if isinstance(items, list): + return [self._round_float(item) for item in items] + elif isinstance(items, float): + return round(items, 5) + else: + return items + + def log(self, runner): + if 'eval_iter_num' in runner.log_buffer.output: + # this doesn't modify runner.iter and is regardless of by_epoch + cur_iter = runner.log_buffer.output.pop('eval_iter_num') + else: + cur_iter = self.get_iter(runner, inner_iter=True) + + log_dict = OrderedDict( + mode=self.get_mode(runner), + epoch=self.get_epoch(runner), + iter=cur_iter) + + # only record lr of the first param group + cur_lr = runner.current_lr() + if isinstance(cur_lr, list): + log_dict['lr'] = cur_lr[0] + else: + assert isinstance(cur_lr, dict) + log_dict['lr'] = {} + for k, lr_ in cur_lr.items(): + assert isinstance(lr_, list) + log_dict['lr'].update({k: lr_[0]}) + + if 'time' in runner.log_buffer.output: + # statistic memory + if torch.cuda.is_available(): + log_dict['memory'] = self._get_max_memory(runner) + + log_dict = dict(log_dict, **runner.log_buffer.output) + + self._log_info(log_dict, runner) + self._dump_log(log_dict, runner) + return log_dict + + def after_run(self, runner): + # copy or upload logs to self.out_dir + if self.out_dir is not None: + for filename in scandir(runner.work_dir, self.out_suffix, True): + local_filepath = osp.join(runner.work_dir, filename) + out_filepath = self.file_client.join_path( + self.out_dir, filename) + with open(local_filepath, 'r') as f: + self.file_client.put_text(f.read(), out_filepath) + + runner.logger.info( + (f'The file {local_filepath} has been uploaded to ' + f'{out_filepath}.')) + + if not self.keep_local: + os.remove(local_filepath) + runner.logger.info( + (f'{local_filepath} was removed due to the ' + '`self.keep_local=False`')) diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py b/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py new file mode 100644 index 0000000000000000000000000000000000000000..9f6808462eb79ab2b04806a5d9f0d3dd079b5ea9 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class WandbLoggerHook(LoggerHook): + + def __init__(self, + init_kwargs=None, + interval=10, + ignore_last=True, + reset_flag=False, + commit=True, + by_epoch=True, + with_step=True): + super(WandbLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.import_wandb() + self.init_kwargs = init_kwargs + self.commit = commit + self.with_step = with_step + + def import_wandb(self): + try: + import wandb + except ImportError: + raise ImportError( + 'Please run "pip install wandb" to install wandb') + self.wandb = wandb + + @master_only + def before_run(self, runner): + super(WandbLoggerHook, self).before_run(runner) + if self.wandb is None: + self.import_wandb() + if self.init_kwargs: + self.wandb.init(**self.init_kwargs) + else: + self.wandb.init() + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + if self.with_step: + self.wandb.log( + tags, step=self.get_iter(runner), commit=self.commit) + else: + tags['global_step'] = self.get_iter(runner) + self.wandb.log(tags, commit=self.commit) + + @master_only + def after_run(self, runner): + self.wandb.join() diff --git a/annotator/uniformer/mmcv/runner/hooks/lr_updater.py b/annotator/uniformer/mmcv/runner/hooks/lr_updater.py new file mode 100644 index 0000000000000000000000000000000000000000..6365908ddf6070086de2ffc0afada46ed2f32256 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/lr_updater.py @@ -0,0 +1,670 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numbers +from math import cos, pi + +import annotator.uniformer.mmcv as mmcv +from .hook import HOOKS, Hook + + +class LrUpdaterHook(Hook): + """LR Scheduler in MMCV. + + Args: + by_epoch (bool): LR changes epoch by epoch + warmup (string): Type of warmup used. It can be None(use no warmup), + 'constant', 'linear' or 'exp' + warmup_iters (int): The number of iterations or epochs that warmup + lasts + warmup_ratio (float): LR used at the beginning of warmup equals to + warmup_ratio * initial_lr + warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters + means the number of epochs that warmup lasts, otherwise means the + number of iteration that warmup lasts + """ + + def __init__(self, + by_epoch=True, + warmup=None, + warmup_iters=0, + warmup_ratio=0.1, + warmup_by_epoch=False): + # validate the "warmup" argument + if warmup is not None: + if warmup not in ['constant', 'linear', 'exp']: + raise ValueError( + f'"{warmup}" is not a supported type for warming up, valid' + ' types are "constant" and "linear"') + if warmup is not None: + assert warmup_iters > 0, \ + '"warmup_iters" must be a positive integer' + assert 0 < warmup_ratio <= 1.0, \ + '"warmup_ratio" must be in range (0,1]' + + self.by_epoch = by_epoch + self.warmup = warmup + self.warmup_iters = warmup_iters + self.warmup_ratio = warmup_ratio + self.warmup_by_epoch = warmup_by_epoch + + if self.warmup_by_epoch: + self.warmup_epochs = self.warmup_iters + self.warmup_iters = None + else: + self.warmup_epochs = None + + self.base_lr = [] # initial lr for all param groups + self.regular_lr = [] # expected lr if no warming up is performed + + def _set_lr(self, runner, lr_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, lr in zip(optim.param_groups, lr_groups[k]): + param_group['lr'] = lr + else: + for param_group, lr in zip(runner.optimizer.param_groups, + lr_groups): + param_group['lr'] = lr + + def get_lr(self, runner, base_lr): + raise NotImplementedError + + def get_regular_lr(self, runner): + if isinstance(runner.optimizer, dict): + lr_groups = {} + for k in runner.optimizer.keys(): + _lr_group = [ + self.get_lr(runner, _base_lr) + for _base_lr in self.base_lr[k] + ] + lr_groups.update({k: _lr_group}) + + return lr_groups + else: + return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr] + + def get_warmup_lr(self, cur_iters): + + def _get_warmup_lr(cur_iters, regular_lr): + if self.warmup == 'constant': + warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr] + elif self.warmup == 'linear': + k = (1 - cur_iters / self.warmup_iters) * (1 - + self.warmup_ratio) + warmup_lr = [_lr * (1 - k) for _lr in regular_lr] + elif self.warmup == 'exp': + k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) + warmup_lr = [_lr * k for _lr in regular_lr] + return warmup_lr + + if isinstance(self.regular_lr, dict): + lr_groups = {} + for key, regular_lr in self.regular_lr.items(): + lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr) + return lr_groups + else: + return _get_warmup_lr(cur_iters, self.regular_lr) + + def before_run(self, runner): + # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved, + # it will be set according to the optimizer params + if isinstance(runner.optimizer, dict): + self.base_lr = {} + for k, optim in runner.optimizer.items(): + for group in optim.param_groups: + group.setdefault('initial_lr', group['lr']) + _base_lr = [ + group['initial_lr'] for group in optim.param_groups + ] + self.base_lr.update({k: _base_lr}) + else: + for group in runner.optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + self.base_lr = [ + group['initial_lr'] for group in runner.optimizer.param_groups + ] + + def before_train_epoch(self, runner): + if self.warmup_iters is None: + epoch_len = len(runner.data_loader) + self.warmup_iters = self.warmup_epochs * epoch_len + + if not self.by_epoch: + return + + self.regular_lr = self.get_regular_lr(runner) + self._set_lr(runner, self.regular_lr) + + def before_train_iter(self, runner): + cur_iter = runner.iter + if not self.by_epoch: + self.regular_lr = self.get_regular_lr(runner) + if self.warmup is None or cur_iter >= self.warmup_iters: + self._set_lr(runner, self.regular_lr) + else: + warmup_lr = self.get_warmup_lr(cur_iter) + self._set_lr(runner, warmup_lr) + elif self.by_epoch: + if self.warmup is None or cur_iter > self.warmup_iters: + return + elif cur_iter == self.warmup_iters: + self._set_lr(runner, self.regular_lr) + else: + warmup_lr = self.get_warmup_lr(cur_iter) + self._set_lr(runner, warmup_lr) + + +@HOOKS.register_module() +class FixedLrUpdaterHook(LrUpdaterHook): + + def __init__(self, **kwargs): + super(FixedLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + return base_lr + + +@HOOKS.register_module() +class StepLrUpdaterHook(LrUpdaterHook): + """Step LR scheduler with min_lr clipping. + + Args: + step (int | list[int]): Step to decay the LR. If an int value is given, + regard it as the decay interval. If a list is given, decay LR at + these steps. + gamma (float, optional): Decay LR ratio. Default: 0.1. + min_lr (float, optional): Minimum LR value to keep. If LR after decay + is lower than `min_lr`, it will be clipped to this value. If None + is given, we don't perform lr clipping. Default: None. + """ + + def __init__(self, step, gamma=0.1, min_lr=None, **kwargs): + if isinstance(step, list): + assert mmcv.is_list_of(step, int) + assert all([s > 0 for s in step]) + elif isinstance(step, int): + assert step > 0 + else: + raise TypeError('"step" must be a list or integer') + self.step = step + self.gamma = gamma + self.min_lr = min_lr + super(StepLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + + # calculate exponential term + if isinstance(self.step, int): + exp = progress // self.step + else: + exp = len(self.step) + for i, s in enumerate(self.step): + if progress < s: + exp = i + break + + lr = base_lr * (self.gamma**exp) + if self.min_lr is not None: + # clip to a minimum value + lr = max(lr, self.min_lr) + return lr + + +@HOOKS.register_module() +class ExpLrUpdaterHook(LrUpdaterHook): + + def __init__(self, gamma, **kwargs): + self.gamma = gamma + super(ExpLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + return base_lr * self.gamma**progress + + +@HOOKS.register_module() +class PolyLrUpdaterHook(LrUpdaterHook): + + def __init__(self, power=1., min_lr=0., **kwargs): + self.power = power + self.min_lr = min_lr + super(PolyLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + coeff = (1 - progress / max_progress)**self.power + return (base_lr - self.min_lr) * coeff + self.min_lr + + +@HOOKS.register_module() +class InvLrUpdaterHook(LrUpdaterHook): + + def __init__(self, gamma, power=1., **kwargs): + self.gamma = gamma + self.power = power + super(InvLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + return base_lr * (1 + self.gamma * progress)**(-self.power) + + +@HOOKS.register_module() +class CosineAnnealingLrUpdaterHook(LrUpdaterHook): + + def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs): + assert (min_lr is None) ^ (min_lr_ratio is None) + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr + return annealing_cos(base_lr, target_lr, progress / max_progress) + + +@HOOKS.register_module() +class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook): + """Flat + Cosine lr schedule. + + Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501 + + Args: + start_percent (float): When to start annealing the learning rate + after the percentage of the total training steps. + The value should be in range [0, 1). + Default: 0.75 + min_lr (float, optional): The minimum lr. Default: None. + min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. + Either `min_lr` or `min_lr_ratio` should be specified. + Default: None. + """ + + def __init__(self, + start_percent=0.75, + min_lr=None, + min_lr_ratio=None, + **kwargs): + assert (min_lr is None) ^ (min_lr_ratio is None) + if start_percent < 0 or start_percent > 1 or not isinstance( + start_percent, float): + raise ValueError( + 'expected float between 0 and 1 start_percent, but ' + f'got {start_percent}') + self.start_percent = start_percent + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + if self.by_epoch: + start = round(runner.max_epochs * self.start_percent) + progress = runner.epoch - start + max_progress = runner.max_epochs - start + else: + start = round(runner.max_iters * self.start_percent) + progress = runner.iter - start + max_progress = runner.max_iters - start + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr + + if progress < 0: + return base_lr + else: + return annealing_cos(base_lr, target_lr, progress / max_progress) + + +@HOOKS.register_module() +class CosineRestartLrUpdaterHook(LrUpdaterHook): + """Cosine annealing with restarts learning rate scheme. + + Args: + periods (list[int]): Periods for each cosine anneling cycle. + restart_weights (list[float], optional): Restart weights at each + restart iteration. Default: [1]. + min_lr (float, optional): The minimum lr. Default: None. + min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. + Either `min_lr` or `min_lr_ratio` should be specified. + Default: None. + """ + + def __init__(self, + periods, + restart_weights=[1], + min_lr=None, + min_lr_ratio=None, + **kwargs): + assert (min_lr is None) ^ (min_lr_ratio is None) + self.periods = periods + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + self.restart_weights = restart_weights + assert (len(self.periods) == len(self.restart_weights) + ), 'periods and restart_weights should have the same length.' + super(CosineRestartLrUpdaterHook, self).__init__(**kwargs) + + self.cumulative_periods = [ + sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) + ] + + def get_lr(self, runner, base_lr): + if self.by_epoch: + progress = runner.epoch + else: + progress = runner.iter + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr + + idx = get_position_from_periods(progress, self.cumulative_periods) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1] + current_periods = self.periods[idx] + + alpha = min((progress - nearest_restart) / current_periods, 1) + return annealing_cos(base_lr, target_lr, alpha, current_weight) + + +def get_position_from_periods(iteration, cumulative_periods): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_periods = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 3. + + Args: + iteration (int): Current iteration. + cumulative_periods (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_periods): + if iteration < period: + return i + raise ValueError(f'Current iteration {iteration} exceeds ' + f'cumulative_periods {cumulative_periods}') + + +@HOOKS.register_module() +class CyclicLrUpdaterHook(LrUpdaterHook): + """Cyclic LR Scheduler. + + Implement the cyclical learning rate policy (CLR) described in + https://arxiv.org/pdf/1506.01186.pdf + + Different from the original paper, we use cosine annealing rather than + triangular policy inside a cycle. This improves the performance in the + 3D detection area. + + Args: + by_epoch (bool): Whether to update LR by epoch. + target_ratio (tuple[float]): Relative ratio of the highest LR and the + lowest LR to the initial LR. + cyclic_times (int): Number of cycles during training + step_ratio_up (float): The ratio of the increasing process of LR in + the total cycle. + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. Default: 'cos'. + """ + + def __init__(self, + by_epoch=False, + target_ratio=(10, 1e-4), + cyclic_times=1, + step_ratio_up=0.4, + anneal_strategy='cos', + **kwargs): + if isinstance(target_ratio, float): + target_ratio = (target_ratio, target_ratio / 1e5) + elif isinstance(target_ratio, tuple): + target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \ + if len(target_ratio) == 1 else target_ratio + else: + raise ValueError('target_ratio should be either float ' + f'or tuple, got {type(target_ratio)}') + + assert len(target_ratio) == 2, \ + '"target_ratio" must be list or tuple of two floats' + assert 0 <= step_ratio_up < 1.0, \ + '"step_ratio_up" must be in range [0,1)' + + self.target_ratio = target_ratio + self.cyclic_times = cyclic_times + self.step_ratio_up = step_ratio_up + self.lr_phases = [] # init lr_phases + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must be one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + + assert not by_epoch, \ + 'currently only support "by_epoch" = False' + super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs) + + def before_run(self, runner): + super(CyclicLrUpdaterHook, self).before_run(runner) + # initiate lr_phases + # total lr_phases are separated as up and down + max_iter_per_phase = runner.max_iters // self.cyclic_times + iter_up_phase = int(self.step_ratio_up * max_iter_per_phase) + self.lr_phases.append( + [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]]) + self.lr_phases.append([ + iter_up_phase, max_iter_per_phase, max_iter_per_phase, + self.target_ratio[0], self.target_ratio[1] + ]) + + def get_lr(self, runner, base_lr): + curr_iter = runner.iter + for (start_iter, end_iter, max_iter_per_phase, start_ratio, + end_ratio) in self.lr_phases: + curr_iter %= max_iter_per_phase + if start_iter <= curr_iter < end_iter: + progress = curr_iter - start_iter + return self.anneal_func(base_lr * start_ratio, + base_lr * end_ratio, + progress / (end_iter - start_iter)) + + +@HOOKS.register_module() +class OneCycleLrUpdaterHook(LrUpdaterHook): + """One Cycle LR Scheduler. + + The 1cycle learning rate policy changes the learning rate after every + batch. The one cycle learning rate policy is described in + https://arxiv.org/pdf/1708.07120.pdf + + Args: + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int, optional): The total number of steps in the cycle. + Note that if a value is not provided here, it will be the max_iter + of runner. Default: None. + pct_start (float): The percentage of the cycle (in number of steps) + spent increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. + Default: 'cos' + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If three_phase is True, use a third phase of the + schedule to annihilate the learning rate according to + final_div_factor instead of modifying the second phase (the first + two phases will be symmetrical about the step indicated by + pct_start). + Default: False + """ + + def __init__(self, + max_lr, + total_steps=None, + pct_start=0.3, + anneal_strategy='cos', + div_factor=25, + final_div_factor=1e4, + three_phase=False, + **kwargs): + # validate by_epoch, currently only support by_epoch = False + if 'by_epoch' not in kwargs: + kwargs['by_epoch'] = False + else: + assert not kwargs['by_epoch'], \ + 'currently only support "by_epoch" = False' + if not isinstance(max_lr, (numbers.Number, list, dict)): + raise ValueError('the type of max_lr must be the one of list or ' + f'dict, but got {type(max_lr)}') + self._max_lr = max_lr + if total_steps is not None: + if not isinstance(total_steps, int): + raise ValueError('the type of total_steps must be int, but' + f'got {type(total_steps)}') + self.total_steps = total_steps + # validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError('expected float between 0 and 1 pct_start, but ' + f'got {pct_start}') + self.pct_start = pct_start + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must be one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + self.div_factor = div_factor + self.final_div_factor = final_div_factor + self.three_phase = three_phase + self.lr_phases = [] # init lr_phases + super(OneCycleLrUpdaterHook, self).__init__(**kwargs) + + def before_run(self, runner): + if hasattr(self, 'total_steps'): + total_steps = self.total_steps + else: + total_steps = runner.max_iters + if total_steps < runner.max_iters: + raise ValueError( + 'The total steps must be greater than or equal to max ' + f'iterations {runner.max_iters} of runner, but total steps ' + f'is {total_steps}.') + + if isinstance(runner.optimizer, dict): + self.base_lr = {} + for k, optim in runner.optimizer.items(): + _max_lr = format_param(k, optim, self._max_lr) + self.base_lr[k] = [lr / self.div_factor for lr in _max_lr] + for group, lr in zip(optim.param_groups, self.base_lr[k]): + group.setdefault('initial_lr', lr) + else: + k = type(runner.optimizer).__name__ + _max_lr = format_param(k, runner.optimizer, self._max_lr) + self.base_lr = [lr / self.div_factor for lr in _max_lr] + for group, lr in zip(runner.optimizer.param_groups, self.base_lr): + group.setdefault('initial_lr', lr) + + if self.three_phase: + self.lr_phases.append( + [float(self.pct_start * total_steps) - 1, 1, self.div_factor]) + self.lr_phases.append([ + float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1 + ]) + self.lr_phases.append( + [total_steps - 1, 1, 1 / self.final_div_factor]) + else: + self.lr_phases.append( + [float(self.pct_start * total_steps) - 1, 1, self.div_factor]) + self.lr_phases.append( + [total_steps - 1, self.div_factor, 1 / self.final_div_factor]) + + def get_lr(self, runner, base_lr): + curr_iter = runner.iter + start_iter = 0 + for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases): + if curr_iter <= end_iter: + pct = (curr_iter - start_iter) / (end_iter - start_iter) + lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr, + pct) + break + start_iter = end_iter + return lr + + +def annealing_cos(start, end, factor, weight=1): + """Calculate annealing cos learning rate. + + Cosine anneal from `weight * start + (1 - weight) * end` to `end` as + percentage goes from 0.0 to 1.0. + + Args: + start (float): The starting learning rate of the cosine annealing. + end (float): The ending learing rate of the cosine annealing. + factor (float): The coefficient of `pi` when calculating the current + percentage. Range from 0.0 to 1.0. + weight (float, optional): The combination factor of `start` and `end` + when calculating the actual starting learning rate. Default to 1. + """ + cos_out = cos(pi * factor) + 1 + return end + 0.5 * weight * (start - end) * cos_out + + +def annealing_linear(start, end, factor): + """Calculate annealing linear learning rate. + + Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0. + + Args: + start (float): The starting learning rate of the linear annealing. + end (float): The ending learing rate of the linear annealing. + factor (float): The coefficient of `pi` when calculating the current + percentage. Range from 0.0 to 1.0. + """ + return start + (end - start) * factor + + +def format_param(name, optim, param): + if isinstance(param, numbers.Number): + return [param] * len(optim.param_groups) + elif isinstance(param, (list, tuple)): # multi param groups + if len(param) != len(optim.param_groups): + raise ValueError(f'expected {len(optim.param_groups)} ' + f'values for {name}, got {len(param)}') + return param + else: # multi optimizers + if name not in param: + raise KeyError(f'{name} is not found in {param.keys()}') + return param[name] diff --git a/annotator/uniformer/mmcv/runner/hooks/memory.py b/annotator/uniformer/mmcv/runner/hooks/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..70cf9a838fb314e3bd3c07aadbc00921a81e83ed --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/memory.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class EmptyCacheHook(Hook): + + def __init__(self, before_epoch=False, after_epoch=True, after_iter=False): + self._before_epoch = before_epoch + self._after_epoch = after_epoch + self._after_iter = after_iter + + def after_iter(self, runner): + if self._after_iter: + torch.cuda.empty_cache() + + def before_epoch(self, runner): + if self._before_epoch: + torch.cuda.empty_cache() + + def after_epoch(self, runner): + if self._after_epoch: + torch.cuda.empty_cache() diff --git a/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py b/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py new file mode 100644 index 0000000000000000000000000000000000000000..60437756ceedf06055ec349df69a25465738d3f0 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py @@ -0,0 +1,493 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import annotator.uniformer.mmcv as mmcv +from .hook import HOOKS, Hook +from .lr_updater import annealing_cos, annealing_linear, format_param + + +class MomentumUpdaterHook(Hook): + + def __init__(self, + by_epoch=True, + warmup=None, + warmup_iters=0, + warmup_ratio=0.9): + # validate the "warmup" argument + if warmup is not None: + if warmup not in ['constant', 'linear', 'exp']: + raise ValueError( + f'"{warmup}" is not a supported type for warming up, valid' + ' types are "constant" and "linear"') + if warmup is not None: + assert warmup_iters > 0, \ + '"warmup_iters" must be a positive integer' + assert 0 < warmup_ratio <= 1.0, \ + '"warmup_momentum" must be in range (0,1]' + + self.by_epoch = by_epoch + self.warmup = warmup + self.warmup_iters = warmup_iters + self.warmup_ratio = warmup_ratio + + self.base_momentum = [] # initial momentum for all param groups + self.regular_momentum = [ + ] # expected momentum if no warming up is performed + + def _set_momentum(self, runner, momentum_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, mom in zip(optim.param_groups, + momentum_groups[k]): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + else: + for param_group, mom in zip(runner.optimizer.param_groups, + momentum_groups): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + + def get_momentum(self, runner, base_momentum): + raise NotImplementedError + + def get_regular_momentum(self, runner): + if isinstance(runner.optimizer, dict): + momentum_groups = {} + for k in runner.optimizer.keys(): + _momentum_group = [ + self.get_momentum(runner, _base_momentum) + for _base_momentum in self.base_momentum[k] + ] + momentum_groups.update({k: _momentum_group}) + return momentum_groups + else: + return [ + self.get_momentum(runner, _base_momentum) + for _base_momentum in self.base_momentum + ] + + def get_warmup_momentum(self, cur_iters): + + def _get_warmup_momentum(cur_iters, regular_momentum): + if self.warmup == 'constant': + warmup_momentum = [ + _momentum / self.warmup_ratio + for _momentum in self.regular_momentum + ] + elif self.warmup == 'linear': + k = (1 - cur_iters / self.warmup_iters) * (1 - + self.warmup_ratio) + warmup_momentum = [ + _momentum / (1 - k) for _momentum in self.regular_mom + ] + elif self.warmup == 'exp': + k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) + warmup_momentum = [ + _momentum / k for _momentum in self.regular_mom + ] + return warmup_momentum + + if isinstance(self.regular_momentum, dict): + momentum_groups = {} + for key, regular_momentum in self.regular_momentum.items(): + momentum_groups[key] = _get_warmup_momentum( + cur_iters, regular_momentum) + return momentum_groups + else: + return _get_warmup_momentum(cur_iters, self.regular_momentum) + + def before_run(self, runner): + # NOTE: when resuming from a checkpoint, + # if 'initial_momentum' is not saved, + # it will be set according to the optimizer params + if isinstance(runner.optimizer, dict): + self.base_momentum = {} + for k, optim in runner.optimizer.items(): + for group in optim.param_groups: + if 'momentum' in group.keys(): + group.setdefault('initial_momentum', group['momentum']) + else: + group.setdefault('initial_momentum', group['betas'][0]) + _base_momentum = [ + group['initial_momentum'] for group in optim.param_groups + ] + self.base_momentum.update({k: _base_momentum}) + else: + for group in runner.optimizer.param_groups: + if 'momentum' in group.keys(): + group.setdefault('initial_momentum', group['momentum']) + else: + group.setdefault('initial_momentum', group['betas'][0]) + self.base_momentum = [ + group['initial_momentum'] + for group in runner.optimizer.param_groups + ] + + def before_train_epoch(self, runner): + if not self.by_epoch: + return + self.regular_mom = self.get_regular_momentum(runner) + self._set_momentum(runner, self.regular_mom) + + def before_train_iter(self, runner): + cur_iter = runner.iter + if not self.by_epoch: + self.regular_mom = self.get_regular_momentum(runner) + if self.warmup is None or cur_iter >= self.warmup_iters: + self._set_momentum(runner, self.regular_mom) + else: + warmup_momentum = self.get_warmup_momentum(cur_iter) + self._set_momentum(runner, warmup_momentum) + elif self.by_epoch: + if self.warmup is None or cur_iter > self.warmup_iters: + return + elif cur_iter == self.warmup_iters: + self._set_momentum(runner, self.regular_mom) + else: + warmup_momentum = self.get_warmup_momentum(cur_iter) + self._set_momentum(runner, warmup_momentum) + + +@HOOKS.register_module() +class StepMomentumUpdaterHook(MomentumUpdaterHook): + """Step momentum scheduler with min value clipping. + + Args: + step (int | list[int]): Step to decay the momentum. If an int value is + given, regard it as the decay interval. If a list is given, decay + momentum at these steps. + gamma (float, optional): Decay momentum ratio. Default: 0.5. + min_momentum (float, optional): Minimum momentum value to keep. If + momentum after decay is lower than this value, it will be clipped + accordingly. If None is given, we don't perform lr clipping. + Default: None. + """ + + def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs): + if isinstance(step, list): + assert mmcv.is_list_of(step, int) + assert all([s > 0 for s in step]) + elif isinstance(step, int): + assert step > 0 + else: + raise TypeError('"step" must be a list or integer') + self.step = step + self.gamma = gamma + self.min_momentum = min_momentum + super(StepMomentumUpdaterHook, self).__init__(**kwargs) + + def get_momentum(self, runner, base_momentum): + progress = runner.epoch if self.by_epoch else runner.iter + + # calculate exponential term + if isinstance(self.step, int): + exp = progress // self.step + else: + exp = len(self.step) + for i, s in enumerate(self.step): + if progress < s: + exp = i + break + + momentum = base_momentum * (self.gamma**exp) + if self.min_momentum is not None: + # clip to a minimum value + momentum = max(momentum, self.min_momentum) + return momentum + + +@HOOKS.register_module() +class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook): + + def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs): + assert (min_momentum is None) ^ (min_momentum_ratio is None) + self.min_momentum = min_momentum + self.min_momentum_ratio = min_momentum_ratio + super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs) + + def get_momentum(self, runner, base_momentum): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + if self.min_momentum_ratio is not None: + target_momentum = base_momentum * self.min_momentum_ratio + else: + target_momentum = self.min_momentum + return annealing_cos(base_momentum, target_momentum, + progress / max_progress) + + +@HOOKS.register_module() +class CyclicMomentumUpdaterHook(MomentumUpdaterHook): + """Cyclic momentum Scheduler. + + Implement the cyclical momentum scheduler policy described in + https://arxiv.org/pdf/1708.07120.pdf + + This momentum scheduler usually used together with the CyclicLRUpdater + to improve the performance in the 3D detection area. + + Attributes: + target_ratio (tuple[float]): Relative ratio of the lowest momentum and + the highest momentum to the initial momentum. + cyclic_times (int): Number of cycles during training + step_ratio_up (float): The ratio of the increasing process of momentum + in the total cycle. + by_epoch (bool): Whether to update momentum by epoch. + """ + + def __init__(self, + by_epoch=False, + target_ratio=(0.85 / 0.95, 1), + cyclic_times=1, + step_ratio_up=0.4, + **kwargs): + if isinstance(target_ratio, float): + target_ratio = (target_ratio, target_ratio / 1e5) + elif isinstance(target_ratio, tuple): + target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \ + if len(target_ratio) == 1 else target_ratio + else: + raise ValueError('target_ratio should be either float ' + f'or tuple, got {type(target_ratio)}') + + assert len(target_ratio) == 2, \ + '"target_ratio" must be list or tuple of two floats' + assert 0 <= step_ratio_up < 1.0, \ + '"step_ratio_up" must be in range [0,1)' + + self.target_ratio = target_ratio + self.cyclic_times = cyclic_times + self.step_ratio_up = step_ratio_up + self.momentum_phases = [] # init momentum_phases + # currently only support by_epoch=False + assert not by_epoch, \ + 'currently only support "by_epoch" = False' + super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs) + + def before_run(self, runner): + super(CyclicMomentumUpdaterHook, self).before_run(runner) + # initiate momentum_phases + # total momentum_phases are separated as up and down + max_iter_per_phase = runner.max_iters // self.cyclic_times + iter_up_phase = int(self.step_ratio_up * max_iter_per_phase) + self.momentum_phases.append( + [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]]) + self.momentum_phases.append([ + iter_up_phase, max_iter_per_phase, max_iter_per_phase, + self.target_ratio[0], self.target_ratio[1] + ]) + + def get_momentum(self, runner, base_momentum): + curr_iter = runner.iter + for (start_iter, end_iter, max_iter_per_phase, start_ratio, + end_ratio) in self.momentum_phases: + curr_iter %= max_iter_per_phase + if start_iter <= curr_iter < end_iter: + progress = curr_iter - start_iter + return annealing_cos(base_momentum * start_ratio, + base_momentum * end_ratio, + progress / (end_iter - start_iter)) + + +@HOOKS.register_module() +class OneCycleMomentumUpdaterHook(MomentumUpdaterHook): + """OneCycle momentum Scheduler. + + This momentum scheduler usually used together with the OneCycleLrUpdater + to improve the performance. + + Args: + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is + 'max_momentum' and learning rate is 'base_lr' + Default: 0.95 + pct_start (float): The percentage of the cycle (in number of steps) + spent increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. + Default: 'cos' + three_phase (bool): If three_phase is True, use a third phase of the + schedule to annihilate the learning rate according to + final_div_factor instead of modifying the second phase (the first + two phases will be symmetrical about the step indicated by + pct_start). + Default: False + """ + + def __init__(self, + base_momentum=0.85, + max_momentum=0.95, + pct_start=0.3, + anneal_strategy='cos', + three_phase=False, + **kwargs): + # validate by_epoch, currently only support by_epoch=False + if 'by_epoch' not in kwargs: + kwargs['by_epoch'] = False + else: + assert not kwargs['by_epoch'], \ + 'currently only support "by_epoch" = False' + if not isinstance(base_momentum, (float, list, dict)): + raise ValueError('base_momentum must be the type among of float,' + 'list or dict.') + self._base_momentum = base_momentum + if not isinstance(max_momentum, (float, list, dict)): + raise ValueError('max_momentum must be the type among of float,' + 'list or dict.') + self._max_momentum = max_momentum + # validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError('Expected float between 0 and 1 pct_start, but ' + f'got {pct_start}') + self.pct_start = pct_start + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must by one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + self.three_phase = three_phase + self.momentum_phases = [] # init momentum_phases + super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs) + + def before_run(self, runner): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + if ('momentum' not in optim.defaults + and 'betas' not in optim.defaults): + raise ValueError('optimizer must support momentum with' + 'option enabled') + self.use_beta1 = 'betas' in optim.defaults + _base_momentum = format_param(k, optim, self._base_momentum) + _max_momentum = format_param(k, optim, self._max_momentum) + for group, b_momentum, m_momentum in zip( + optim.param_groups, _base_momentum, _max_momentum): + if self.use_beta1: + _, beta2 = group['betas'] + group['betas'] = (m_momentum, beta2) + else: + group['momentum'] = m_momentum + group['base_momentum'] = b_momentum + group['max_momentum'] = m_momentum + else: + optim = runner.optimizer + if ('momentum' not in optim.defaults + and 'betas' not in optim.defaults): + raise ValueError('optimizer must support momentum with' + 'option enabled') + self.use_beta1 = 'betas' in optim.defaults + k = type(optim).__name__ + _base_momentum = format_param(k, optim, self._base_momentum) + _max_momentum = format_param(k, optim, self._max_momentum) + for group, b_momentum, m_momentum in zip(optim.param_groups, + _base_momentum, + _max_momentum): + if self.use_beta1: + _, beta2 = group['betas'] + group['betas'] = (m_momentum, beta2) + else: + group['momentum'] = m_momentum + group['base_momentum'] = b_momentum + group['max_momentum'] = m_momentum + + if self.three_phase: + self.momentum_phases.append({ + 'end_iter': + float(self.pct_start * runner.max_iters) - 1, + 'start_momentum': + 'max_momentum', + 'end_momentum': + 'base_momentum' + }) + self.momentum_phases.append({ + 'end_iter': + float(2 * self.pct_start * runner.max_iters) - 2, + 'start_momentum': + 'base_momentum', + 'end_momentum': + 'max_momentum' + }) + self.momentum_phases.append({ + 'end_iter': runner.max_iters - 1, + 'start_momentum': 'max_momentum', + 'end_momentum': 'max_momentum' + }) + else: + self.momentum_phases.append({ + 'end_iter': + float(self.pct_start * runner.max_iters) - 1, + 'start_momentum': + 'max_momentum', + 'end_momentum': + 'base_momentum' + }) + self.momentum_phases.append({ + 'end_iter': runner.max_iters - 1, + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum' + }) + + def _set_momentum(self, runner, momentum_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, mom in zip(optim.param_groups, + momentum_groups[k]): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + else: + for param_group, mom in zip(runner.optimizer.param_groups, + momentum_groups): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + + def get_momentum(self, runner, param_group): + curr_iter = runner.iter + start_iter = 0 + for i, phase in enumerate(self.momentum_phases): + end_iter = phase['end_iter'] + if curr_iter <= end_iter or i == len(self.momentum_phases) - 1: + pct = (curr_iter - start_iter) / (end_iter - start_iter) + momentum = self.anneal_func( + param_group[phase['start_momentum']], + param_group[phase['end_momentum']], pct) + break + start_iter = end_iter + return momentum + + def get_regular_momentum(self, runner): + if isinstance(runner.optimizer, dict): + momentum_groups = {} + for k, optim in runner.optimizer.items(): + _momentum_group = [ + self.get_momentum(runner, param_group) + for param_group in optim.param_groups + ] + momentum_groups.update({k: _momentum_group}) + return momentum_groups + else: + momentum_groups = [] + for param_group in runner.optimizer.param_groups: + momentum_groups.append(self.get_momentum(runner, param_group)) + return momentum_groups diff --git a/annotator/uniformer/mmcv/runner/hooks/optimizer.py b/annotator/uniformer/mmcv/runner/hooks/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef3e9ff8f9c6926e32bdf027612267b64ed80df --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/optimizer.py @@ -0,0 +1,508 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from collections import defaultdict +from itertools import chain + +from torch.nn.utils import clip_grad + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version +from ..dist_utils import allreduce_grads +from ..fp16_utils import LossScaler, wrap_fp16_model +from .hook import HOOKS, Hook + +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported + # and used; otherwise, auto fp16 will adopt mmcv's implementation. + from torch.cuda.amp import GradScaler +except ImportError: + pass + + +@HOOKS.register_module() +class OptimizerHook(Hook): + + def __init__(self, grad_clip=None): + self.grad_clip = grad_clip + + def clip_grads(self, params): + params = list( + filter(lambda p: p.requires_grad and p.grad is not None, params)) + if len(params) > 0: + return clip_grad.clip_grad_norm_(params, **self.grad_clip) + + def after_train_iter(self, runner): + runner.optimizer.zero_grad() + runner.outputs['loss'].backward() + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + + +@HOOKS.register_module() +class GradientCumulativeOptimizerHook(OptimizerHook): + """Optimizer Hook implements multi-iters gradient cumulating. + + Args: + cumulative_iters (int, optional): Num of gradient cumulative iters. + The optimizer will step every `cumulative_iters` iters. + Defaults to 1. + + Examples: + >>> # Use cumulative_iters to simulate a large batch size + >>> # It is helpful when the hardware cannot handle a large batch size. + >>> loader = DataLoader(data, batch_size=64) + >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4) + >>> # almost equals to + >>> loader = DataLoader(data, batch_size=256) + >>> optim_hook = OptimizerHook() + """ + + def __init__(self, cumulative_iters=1, **kwargs): + super(GradientCumulativeOptimizerHook, self).__init__(**kwargs) + + assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \ + f'cumulative_iters only accepts positive int, but got ' \ + f'{type(cumulative_iters)} instead.' + + self.cumulative_iters = cumulative_iters + self.divisible_iters = 0 + self.remainder_iters = 0 + self.initialized = False + + def has_batch_norm(self, module): + if isinstance(module, _BatchNorm): + return True + for m in module.children(): + if self.has_batch_norm(m): + return True + return False + + def _init(self, runner): + if runner.iter % self.cumulative_iters != 0: + runner.logger.warning( + 'Resume iter number is not divisible by cumulative_iters in ' + 'GradientCumulativeOptimizerHook, which means the gradient of ' + 'some iters is lost and the result may be influenced slightly.' + ) + + if self.has_batch_norm(runner.model) and self.cumulative_iters > 1: + runner.logger.warning( + 'GradientCumulativeOptimizerHook may slightly decrease ' + 'performance if the model has BatchNorm layers.') + + residual_iters = runner.max_iters - runner.iter + + self.divisible_iters = ( + residual_iters // self.cumulative_iters * self.cumulative_iters) + self.remainder_iters = residual_iters - self.divisible_iters + + self.initialized = True + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + loss = runner.outputs['loss'] + loss = loss / loss_factor + loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + runner.optimizer.zero_grad() + + +if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook (using PyTorch's implementation). + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of GradScalar. + Defaults to 512. For Pytorch >= 1.6, mmcv uses official + implementation of GradScaler. If you use a dict version of + loss_scale to create GradScaler, please refer to: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler + for the parameters. + + Examples: + >>> loss_scale = dict( + ... init_scale=65536.0, + ... growth_factor=2.0, + ... backoff_factor=0.5, + ... growth_interval=2000 + ... ) + >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale) + """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + self._scale_update_param = None + if loss_scale == 'dynamic': + self.loss_scaler = GradScaler() + elif isinstance(loss_scale, float): + self._scale_update_param = loss_scale + self.loss_scaler = GradScaler(init_scale=loss_scale) + elif isinstance(loss_scale, dict): + self.loss_scaler = GradScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner): + """Preparing steps before Mixed Precision Training.""" + # wrap model mode to fp16 + wrap_fp16_model(runner.model) + # resume from state dict + if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: + scaler_state_dict = runner.meta['fp16']['loss_scaler'] + self.loss_scaler.load_state_dict(scaler_state_dict) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer to + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler. + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients. + 3. Unscale the optimizer’s gradient tensors. + 4. Call optimizer.step() and update scale factor. + 5. Save loss_scaler state_dict for resume purpose. + """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + + self.loss_scaler.scale(runner.outputs['loss']).backward() + self.loss_scaler.unscale_(runner.optimizer) + # grad clip + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, + Fp16OptimizerHook): + """Fp16 optimizer Hook (using PyTorch's implementation) implements + multi-iters gradient cumulating. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. + """ + + def __init__(self, *args, **kwargs): + super(GradientCumulativeFp16OptimizerHook, + self).__init__(*args, **kwargs) + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + loss = runner.outputs['loss'] + loss = loss / loss_factor + + self.loss_scaler.scale(loss).backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + self.loss_scaler.unscale_(runner.optimizer) + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() + +else: + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook (mmcv's implementation). + + The steps of fp16 optimizer is as follows. + 1. Scale the loss value. + 2. BP in the fp16 model. + 2. Copy gradients from fp16 model to fp32 weights. + 3. Update fp32 weights. + 4. Copy updated parameters from fp32 weights to fp16 model. + + Refer to https://arxiv.org/abs/1710.03740 for more details. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of LossScaler. + Defaults to 512. + """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + if loss_scale == 'dynamic': + self.loss_scaler = LossScaler(mode='dynamic') + elif isinstance(loss_scale, float): + self.loss_scaler = LossScaler( + init_scale=loss_scale, mode='static') + elif isinstance(loss_scale, dict): + self.loss_scaler = LossScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner): + """Preparing steps before Mixed Precision Training. + + 1. Make a master copy of fp32 weights for optimization. + 2. Convert the main model from fp32 to fp16. + """ + # keep a copy of fp32 weights + old_groups = runner.optimizer.param_groups + runner.optimizer.param_groups = copy.deepcopy( + runner.optimizer.param_groups) + state = defaultdict(dict) + p_map = { + old_p: p + for old_p, p in zip( + chain(*(g['params'] for g in old_groups)), + chain(*(g['params'] + for g in runner.optimizer.param_groups))) + } + for k, v in runner.optimizer.state.items(): + state[p_map[k]] = v + runner.optimizer.state = state + # convert model to fp16 + wrap_fp16_model(runner.model) + # resume from state dict + if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: + scaler_state_dict = runner.meta['fp16']['loss_scaler'] + self.loss_scaler.load_state_dict(scaler_state_dict) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer `loss_scalar.py` + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients (fp16). + 3. Copy gradients from the model to the fp32 weight copy. + 4. Scale the gradients back and update the fp32 weight copy. + 5. Copy back the params from fp32 weight copy to the fp16 model. + 6. Save loss_scaler state_dict for resume purpose. + """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + # scale the loss value + scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale + scaled_loss.backward() + # copy fp16 grads in the model to fp32 params in the optimizer + + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + self.loss_scaler.update_scale(has_overflow) + if has_overflow: + runner.logger.warning('Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, + Fp16OptimizerHook): + """Fp16 optimizer Hook (using mmcv implementation) implements multi- + iters gradient cumulating.""" + + def __init__(self, *args, **kwargs): + super(GradientCumulativeFp16OptimizerHook, + self).__init__(*args, **kwargs) + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + + loss = runner.outputs['loss'] + loss = loss / loss_factor + + # scale the loss value + scaled_loss = loss * self.loss_scaler.loss_scale + scaled_loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + else: + runner.logger.warning( + 'Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + self.loss_scaler.update_scale(has_overflow) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() diff --git a/annotator/uniformer/mmcv/runner/hooks/profiler.py b/annotator/uniformer/mmcv/runner/hooks/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..b70236997eec59c2209ef351ae38863b4112d0ec --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/profiler.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Callable, List, Optional, Union + +import torch + +from ..dist_utils import master_only +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class ProfilerHook(Hook): + """Profiler to analyze performance during training. + + PyTorch Profiler is a tool that allows the collection of the performance + metrics during the training. More details on Profiler can be found at + https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile + + Args: + by_epoch (bool): Profile performance by epoch or by iteration. + Default: True. + profile_iters (int): Number of iterations for profiling. + If ``by_epoch=True``, profile_iters indicates that they are the + first profile_iters epochs at the beginning of the + training, otherwise it indicates the first profile_iters + iterations. Default: 1. + activities (list[str]): List of activity groups (CPU, CUDA) to use in + profiling. Default: ['cpu', 'cuda']. + schedule (dict, optional): Config of generating the callable schedule. + if schedule is None, profiler will not add step markers into the + trace and table view. Default: None. + on_trace_ready (callable, dict): Either a handler or a dict of generate + handler. Default: None. + record_shapes (bool): Save information about operator's input shapes. + Default: False. + profile_memory (bool): Track tensor memory allocation/deallocation. + Default: False. + with_stack (bool): Record source information (file and line number) + for the ops. Default: False. + with_flops (bool): Use formula to estimate the FLOPS of specific + operators (matrix multiplication and 2D convolution). + Default: False. + json_trace_path (str, optional): Exports the collected trace in Chrome + JSON format. Default: None. + + Example: + >>> runner = ... # instantiate a Runner + >>> # tensorboard trace + >>> trace_config = dict(type='tb_trace', dir_name='work_dir') + >>> profiler_config = dict(on_trace_ready=trace_config) + >>> runner.register_profiler_hook(profiler_config) + >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)]) + """ + + def __init__(self, + by_epoch: bool = True, + profile_iters: int = 1, + activities: List[str] = ['cpu', 'cuda'], + schedule: Optional[dict] = None, + on_trace_ready: Optional[Union[Callable, dict]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + json_trace_path: Optional[str] = None) -> None: + try: + from torch import profiler # torch version >= 1.8.1 + except ImportError: + raise ImportError('profiler is the new feature of torch1.8.1, ' + f'but your version is {torch.__version__}') + + assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.' + self.by_epoch = by_epoch + + if profile_iters < 1: + raise ValueError('profile_iters should be greater than 0, but got ' + f'{profile_iters}') + self.profile_iters = profile_iters + + if not isinstance(activities, list): + raise ValueError( + f'activities should be list, but got {type(activities)}') + self.activities = [] + for activity in activities: + activity = activity.lower() + if activity == 'cpu': + self.activities.append(profiler.ProfilerActivity.CPU) + elif activity == 'cuda': + self.activities.append(profiler.ProfilerActivity.CUDA) + else: + raise ValueError( + f'activity should be "cpu" or "cuda", but got {activity}') + + if schedule is not None: + self.schedule = profiler.schedule(**schedule) + else: + self.schedule = None + + self.on_trace_ready = on_trace_ready + self.record_shapes = record_shapes + self.profile_memory = profile_memory + self.with_stack = with_stack + self.with_flops = with_flops + self.json_trace_path = json_trace_path + + @master_only + def before_run(self, runner): + if self.by_epoch and runner.max_epochs < self.profile_iters: + raise ValueError('self.profile_iters should not be greater than ' + f'{runner.max_epochs}') + + if not self.by_epoch and runner.max_iters < self.profile_iters: + raise ValueError('self.profile_iters should not be greater than ' + f'{runner.max_iters}') + + if callable(self.on_trace_ready): # handler + _on_trace_ready = self.on_trace_ready + elif isinstance(self.on_trace_ready, dict): # config of handler + trace_cfg = self.on_trace_ready.copy() + trace_type = trace_cfg.pop('type') # log_trace handler + if trace_type == 'log_trace': + + def _log_handler(prof): + print(prof.key_averages().table(**trace_cfg)) + + _on_trace_ready = _log_handler + elif trace_type == 'tb_trace': # tensorboard_trace handler + try: + import torch_tb_profiler # noqa: F401 + except ImportError: + raise ImportError('please run "pip install ' + 'torch-tb-profiler" to install ' + 'torch_tb_profiler') + _on_trace_ready = torch.profiler.tensorboard_trace_handler( + **trace_cfg) + else: + raise ValueError('trace_type should be "log_trace" or ' + f'"tb_trace", but got {trace_type}') + elif self.on_trace_ready is None: + _on_trace_ready = None # type: ignore + else: + raise ValueError('on_trace_ready should be handler, dict or None, ' + f'but got {type(self.on_trace_ready)}') + + if runner.max_epochs > 1: + warnings.warn(f'profiler will profile {runner.max_epochs} epochs ' + 'instead of 1 epoch. Since profiler will slow down ' + 'the training, it is recommended to train 1 epoch ' + 'with ProfilerHook and adjust your setting according' + ' to the profiler summary. During normal training ' + '(epoch > 1), you may disable the ProfilerHook.') + + self.profiler = torch.profiler.profile( + activities=self.activities, + schedule=self.schedule, + on_trace_ready=_on_trace_ready, + record_shapes=self.record_shapes, + profile_memory=self.profile_memory, + with_stack=self.with_stack, + with_flops=self.with_flops) + + self.profiler.__enter__() + runner.logger.info('profiler is profiling...') + + @master_only + def after_train_epoch(self, runner): + if self.by_epoch and runner.epoch == self.profile_iters - 1: + runner.logger.info('profiler may take a few minutes...') + self.profiler.__exit__(None, None, None) + if self.json_trace_path is not None: + self.profiler.export_chrome_trace(self.json_trace_path) + + @master_only + def after_train_iter(self, runner): + self.profiler.step() + if not self.by_epoch and runner.iter == self.profile_iters - 1: + runner.logger.info('profiler may take a few minutes...') + self.profiler.__exit__(None, None, None) + if self.json_trace_path is not None: + self.profiler.export_chrome_trace(self.json_trace_path) diff --git a/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py b/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0dc6bdd8df5775857028aaed5444c0f59caf80 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class DistSamplerSeedHook(Hook): + """Data-loading sampler for distributed training. + + When distributed training, it is only useful in conjunction with + :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same + purpose with :obj:`IterLoader`. + """ + + def before_epoch(self, runner): + if hasattr(runner.data_loader.sampler, 'set_epoch'): + # in case the data loader uses `SequentialSampler` in Pytorch + runner.data_loader.sampler.set_epoch(runner.epoch) + elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'): + # batch sampler in pytorch warps the sampler as its attributes. + runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch) diff --git a/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py b/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..6376b7ff894280cb2782243b25e8973650591577 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..dist_utils import allreduce_params +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class SyncBuffersHook(Hook): + """Synchronize model buffers such as running_mean and running_var in BN at + the end of each epoch. + + Args: + distributed (bool): Whether distributed training is used. It is + effective only for distributed training. Defaults to True. + """ + + def __init__(self, distributed=True): + self.distributed = distributed + + def after_epoch(self, runner): + """All-reduce model buffers at the end of each epoch.""" + if self.distributed: + allreduce_params(runner.model.buffers()) diff --git a/annotator/uniformer/mmcv/runner/iter_based_runner.py b/annotator/uniformer/mmcv/runner/iter_based_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..1df4de8c0285669dec9b014dfd1f3dd1600f0831 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/iter_based_runner.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import platform +import shutil +import time +import warnings + +import torch +from torch.optim import Optimizer + +import annotator.uniformer.mmcv as mmcv +from .base_runner import BaseRunner +from .builder import RUNNERS +from .checkpoint import save_checkpoint +from .hooks import IterTimerHook +from .utils import get_host_info + + +class IterLoader: + + def __init__(self, dataloader): + self._dataloader = dataloader + self.iter_loader = iter(self._dataloader) + self._epoch = 0 + + @property + def epoch(self): + return self._epoch + + def __next__(self): + try: + data = next(self.iter_loader) + except StopIteration: + self._epoch += 1 + if hasattr(self._dataloader.sampler, 'set_epoch'): + self._dataloader.sampler.set_epoch(self._epoch) + time.sleep(2) # Prevent possible deadlock during epoch transition + self.iter_loader = iter(self._dataloader) + data = next(self.iter_loader) + + return data + + def __len__(self): + return len(self._dataloader) + + +@RUNNERS.register_module() +class IterBasedRunner(BaseRunner): + """Iteration-based Runner. + + This runner train models iteration by iteration. + """ + + def train(self, data_loader, **kwargs): + self.model.train() + self.mode = 'train' + self.data_loader = data_loader + self._epoch = data_loader.epoch + data_batch = next(data_loader) + self.call_hook('before_train_iter') + outputs = self.model.train_step(data_batch, self.optimizer, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('model.train_step() must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + self.call_hook('after_train_iter') + self._inner_iter += 1 + self._iter += 1 + + @torch.no_grad() + def val(self, data_loader, **kwargs): + self.model.eval() + self.mode = 'val' + self.data_loader = data_loader + data_batch = next(data_loader) + self.call_hook('before_val_iter') + outputs = self.model.val_step(data_batch, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('model.val_step() must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + self.call_hook('after_val_iter') + self._inner_iter += 1 + + def run(self, data_loaders, workflow, max_iters=None, **kwargs): + """Start running. + + Args: + data_loaders (list[:obj:`DataLoader`]): Dataloaders for training + and validation. + workflow (list[tuple]): A list of (phase, iters) to specify the + running order and iterations. E.g, [('train', 10000), + ('val', 1000)] means running 10000 iterations for training and + 1000 iterations for validation, iteratively. + """ + assert isinstance(data_loaders, list) + assert mmcv.is_list_of(workflow, tuple) + assert len(data_loaders) == len(workflow) + if max_iters is not None: + warnings.warn( + 'setting max_iters in run is deprecated, ' + 'please set max_iters in runner_config', DeprecationWarning) + self._max_iters = max_iters + assert self._max_iters is not None, ( + 'max_iters must be specified during instantiation') + + work_dir = self.work_dir if self.work_dir is not None else 'NONE' + self.logger.info('Start running, host: %s, work_dir: %s', + get_host_info(), work_dir) + self.logger.info('Hooks will be executed in the following order:\n%s', + self.get_hook_info()) + self.logger.info('workflow: %s, max: %d iters', workflow, + self._max_iters) + self.call_hook('before_run') + + iter_loaders = [IterLoader(x) for x in data_loaders] + + self.call_hook('before_epoch') + + while self.iter < self._max_iters: + for i, flow in enumerate(workflow): + self._inner_iter = 0 + mode, iters = flow + if not isinstance(mode, str) or not hasattr(self, mode): + raise ValueError( + 'runner has no method named "{}" to run a workflow'. + format(mode)) + iter_runner = getattr(self, mode) + for _ in range(iters): + if mode == 'train' and self.iter >= self._max_iters: + break + iter_runner(iter_loaders[i], **kwargs) + + time.sleep(1) # wait for some hooks like loggers to finish + self.call_hook('after_epoch') + self.call_hook('after_run') + + def resume(self, + checkpoint, + resume_optimizer=True, + map_location='default'): + """Resume model from checkpoint. + + Args: + checkpoint (str): Checkpoint to resume from. + resume_optimizer (bool, optional): Whether resume the optimizer(s) + if the checkpoint file includes optimizer(s). Default to True. + map_location (str, optional): Same as :func:`torch.load`. + Default to 'default'. + """ + if map_location == 'default': + device_id = torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + checkpoint = self.load_checkpoint( + checkpoint, map_location=map_location) + + self._epoch = checkpoint['meta']['epoch'] + self._iter = checkpoint['meta']['iter'] + self._inner_iter = checkpoint['meta']['iter'] + if 'optimizer' in checkpoint and resume_optimizer: + if isinstance(self.optimizer, Optimizer): + self.optimizer.load_state_dict(checkpoint['optimizer']) + elif isinstance(self.optimizer, dict): + for k in self.optimizer.keys(): + self.optimizer[k].load_state_dict( + checkpoint['optimizer'][k]) + else: + raise TypeError( + 'Optimizer should be dict or torch.optim.Optimizer ' + f'but got {type(self.optimizer)}') + + self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}') + + def save_checkpoint(self, + out_dir, + filename_tmpl='iter_{}.pth', + meta=None, + save_optimizer=True, + create_symlink=True): + """Save checkpoint to file. + + Args: + out_dir (str): Directory to save checkpoint files. + filename_tmpl (str, optional): Checkpoint file template. + Defaults to 'iter_{}.pth'. + meta (dict, optional): Metadata to be saved in checkpoint. + Defaults to None. + save_optimizer (bool, optional): Whether save optimizer. + Defaults to True. + create_symlink (bool, optional): Whether create symlink to the + latest checkpoint file. Defaults to True. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError( + f'meta should be a dict or None, but got {type(meta)}') + if self.meta is not None: + meta.update(self.meta) + # Note: meta.update(self.meta) should be done before + # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise + # there will be problems with resumed checkpoints. + # More details in https://github.com/open-mmlab/mmcv/pull/1108 + meta.update(epoch=self.epoch + 1, iter=self.iter) + + filename = filename_tmpl.format(self.iter + 1) + filepath = osp.join(out_dir, filename) + optimizer = self.optimizer if save_optimizer else None + save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) + # in some environments, `os.symlink` is not supported, you may need to + # set `create_symlink` to False + if create_symlink: + dst_file = osp.join(out_dir, 'latest.pth') + if platform.system() != 'Windows': + mmcv.symlink(filename, dst_file) + else: + shutil.copy(filepath, dst_file) + + def register_training_hooks(self, + lr_config, + optimizer_config=None, + checkpoint_config=None, + log_config=None, + momentum_config=None, + custom_hooks_config=None): + """Register default hooks for iter-based training. + + Checkpoint hook, optimizer stepper hook and logger hooks will be set to + `by_epoch=False` by default. + + Default hooks include: + + +----------------------+-------------------------+ + | Hooks | Priority | + +======================+=========================+ + | LrUpdaterHook | VERY_HIGH (10) | + +----------------------+-------------------------+ + | MomentumUpdaterHook | HIGH (30) | + +----------------------+-------------------------+ + | OptimizerStepperHook | ABOVE_NORMAL (40) | + +----------------------+-------------------------+ + | CheckpointSaverHook | NORMAL (50) | + +----------------------+-------------------------+ + | IterTimerHook | LOW (70) | + +----------------------+-------------------------+ + | LoggerHook(s) | VERY_LOW (90) | + +----------------------+-------------------------+ + | CustomHook(s) | defaults to NORMAL (50) | + +----------------------+-------------------------+ + + If custom hooks have same priority with default hooks, custom hooks + will be triggered after default hooks. + """ + if checkpoint_config is not None: + checkpoint_config.setdefault('by_epoch', False) + if lr_config is not None: + lr_config.setdefault('by_epoch', False) + if log_config is not None: + for info in log_config['hooks']: + info.setdefault('by_epoch', False) + super(IterBasedRunner, self).register_training_hooks( + lr_config=lr_config, + momentum_config=momentum_config, + optimizer_config=optimizer_config, + checkpoint_config=checkpoint_config, + log_config=log_config, + timer_config=IterTimerHook(), + custom_hooks_config=custom_hooks_config) diff --git a/annotator/uniformer/mmcv/runner/log_buffer.py b/annotator/uniformer/mmcv/runner/log_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..d949e2941c5400088c7cd8a1dc893d8b233ae785 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/log_buffer.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +import numpy as np + + +class LogBuffer: + + def __init__(self): + self.val_history = OrderedDict() + self.n_history = OrderedDict() + self.output = OrderedDict() + self.ready = False + + def clear(self): + self.val_history.clear() + self.n_history.clear() + self.clear_output() + + def clear_output(self): + self.output.clear() + self.ready = False + + def update(self, vars, count=1): + assert isinstance(vars, dict) + for key, var in vars.items(): + if key not in self.val_history: + self.val_history[key] = [] + self.n_history[key] = [] + self.val_history[key].append(var) + self.n_history[key].append(count) + + def average(self, n=0): + """Average latest n values or all values.""" + assert n >= 0 + for key in self.val_history: + values = np.array(self.val_history[key][-n:]) + nums = np.array(self.n_history[key][-n:]) + avg = np.sum(values * nums) / np.sum(nums) + self.output[key] = avg + self.ready = True diff --git a/annotator/uniformer/mmcv/runner/optimizer/__init__.py b/annotator/uniformer/mmcv/runner/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53c34d0470992cbc374f29681fdd00dc0e57968d --- /dev/null +++ b/annotator/uniformer/mmcv/runner/optimizer/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer, + build_optimizer_constructor) +from .default_constructor import DefaultOptimizerConstructor + +__all__ = [ + 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor', + 'build_optimizer', 'build_optimizer_constructor' +] diff --git a/annotator/uniformer/mmcv/runner/optimizer/builder.py b/annotator/uniformer/mmcv/runner/optimizer/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f9234eed8f1f186d9d8dfda34562157ee39bdb3a --- /dev/null +++ b/annotator/uniformer/mmcv/runner/optimizer/builder.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect + +import torch + +from ...utils import Registry, build_from_cfg + +OPTIMIZERS = Registry('optimizer') +OPTIMIZER_BUILDERS = Registry('optimizer builder') + + +def register_torch_optimizers(): + torch_optimizers = [] + for module_name in dir(torch.optim): + if module_name.startswith('__'): + continue + _optim = getattr(torch.optim, module_name) + if inspect.isclass(_optim) and issubclass(_optim, + torch.optim.Optimizer): + OPTIMIZERS.register_module()(_optim) + torch_optimizers.append(module_name) + return torch_optimizers + + +TORCH_OPTIMIZERS = register_torch_optimizers() + + +def build_optimizer_constructor(cfg): + return build_from_cfg(cfg, OPTIMIZER_BUILDERS) + + +def build_optimizer(model, cfg): + optimizer_cfg = copy.deepcopy(cfg) + constructor_type = optimizer_cfg.pop('constructor', + 'DefaultOptimizerConstructor') + paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + optim_constructor = build_optimizer_constructor( + dict( + type=constructor_type, + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg)) + optimizer = optim_constructor(model) + return optimizer diff --git a/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py b/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0da3503b75441738efe38d70352b55a210a34a --- /dev/null +++ b/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py @@ -0,0 +1,249 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +from torch.nn import GroupNorm, LayerNorm + +from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of +from annotator.uniformer.mmcv.utils.ext_loader import check_ops_exist +from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS + + +@OPTIMIZER_BUILDERS.register_module() +class DefaultOptimizerConstructor: + """Default constructor for optimizers. + + By default each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain the following fields: + + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will + be ignored. It should be noted that the aforementioned ``key`` is the + longest key that is a substring of the name of the parameter. If there + are multiple matched keys with the same length, then the key with lower + alphabet order will be chosen. + ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` + and ``decay_mult``. See Example 2 below. + - ``bias_lr_mult`` (float): It will be multiplied to the learning + rate for all bias parameters (except for those in normalization + layers and offset layers of DCN). + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in + normalization layers, depthwise conv layers, offset layers of DCN). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization + layers. + - ``dwconv_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of depthwise conv + layers. + - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning + rate for parameters of offset layer in the deformable convs + of a model. + - ``bypass_duplicate`` (bool): If true, the duplicate parameters + would not be added into optimizer. Default: False. + + Note: + 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will + override the effect of ``bias_lr_mult`` in the bias of offset + layer. So be careful when using both ``bias_lr_mult`` and + ``dcn_offset_lr_mult``. If you wish to apply both of them to the + offset layer in deformable convs, set ``dcn_offset_lr_mult`` + to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``. + 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will + apply it to all the DCN layers in the model. So be careful when + the model contains multiple DCN layers in places other than + backbone. + + Args: + model (:obj:`nn.Module`): The model with parameters to be optimized. + optimizer_cfg (dict): The config dict of the optimizer. + Positional fields are + + - `type`: class name of the optimizer. + + Optional fields are + + - any arguments of the corresponding optimizer type, e.g., + lr, weight_decay, momentum, etc. + paramwise_cfg (dict, optional): Parameter-wise options. + + Example 1: + >>> model = torch.nn.modules.Conv1d(1, 1, 1) + >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, + >>> weight_decay=0.0001) + >>> paramwise_cfg = dict(norm_decay_mult=0.) + >>> optim_builder = DefaultOptimizerConstructor( + >>> optimizer_cfg, paramwise_cfg) + >>> optimizer = optim_builder(model) + + Example 2: + >>> # assume model have attribute model.backbone and model.cls_head + >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95) + >>> paramwise_cfg = dict(custom_keys={ + '.backbone': dict(lr_mult=0.1, decay_mult=0.9)}) + >>> optim_builder = DefaultOptimizerConstructor( + >>> optimizer_cfg, paramwise_cfg) + >>> optimizer = optim_builder(model) + >>> # Then the `lr` and `weight_decay` for model.backbone is + >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for + >>> # model.cls_head is (0.01, 0.95). + """ + + def __init__(self, optimizer_cfg, paramwise_cfg=None): + if not isinstance(optimizer_cfg, dict): + raise TypeError('optimizer_cfg should be a dict', + f'but got {type(optimizer_cfg)}') + self.optimizer_cfg = optimizer_cfg + self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg + self.base_lr = optimizer_cfg.get('lr', None) + self.base_wd = optimizer_cfg.get('weight_decay', None) + self._validate_cfg() + + def _validate_cfg(self): + if not isinstance(self.paramwise_cfg, dict): + raise TypeError('paramwise_cfg should be None or a dict, ' + f'but got {type(self.paramwise_cfg)}') + + if 'custom_keys' in self.paramwise_cfg: + if not isinstance(self.paramwise_cfg['custom_keys'], dict): + raise TypeError( + 'If specified, custom_keys must be a dict, ' + f'but got {type(self.paramwise_cfg["custom_keys"])}') + if self.base_wd is None: + for key in self.paramwise_cfg['custom_keys']: + if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]: + raise ValueError('base_wd should not be None') + + # get base lr and weight decay + # weight_decay must be explicitly specified if mult is specified + if ('bias_decay_mult' in self.paramwise_cfg + or 'norm_decay_mult' in self.paramwise_cfg + or 'dwconv_decay_mult' in self.paramwise_cfg): + if self.base_wd is None: + raise ValueError('base_wd should not be None') + + def _is_in(self, param_group, param_group_list): + assert is_list_of(param_group_list, dict) + param = set(param_group['params']) + param_set = set() + for group in param_group_list: + param_set.update(set(group['params'])) + + return not param.isdisjoint(param_set) + + def add_params(self, params, module, prefix='', is_dcn_module=None): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.) + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.) + dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.) + bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) + dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + is_dwconv = ( + isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + if not param.requires_grad: + params.append(param_group) + continue + if bypass_duplicate and self._is_in(param_group, params): + warnings.warn(f'{prefix} is duplicate. It is skipped since ' + f'bypass_duplicate={bypass_duplicate}') + continue + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in sorted_keys: + if key in f'{prefix}.{name}': + is_custom = True + lr_mult = custom_keys[key].get('lr_mult', 1.) + param_group['lr'] = self.base_lr * lr_mult + if self.base_wd is not None: + decay_mult = custom_keys[key].get('decay_mult', 1.) + param_group['weight_decay'] = self.base_wd * decay_mult + break + + if not is_custom: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if name == 'bias' and not (is_norm or is_dcn_module): + param_group['lr'] = self.base_lr * bias_lr_mult + + if (prefix.find('conv_offset') != -1 and is_dcn_module + and isinstance(module, torch.nn.Conv2d)): + # deal with both dcn_offset's bias & weight + param_group['lr'] = self.base_lr * dcn_offset_lr_mult + + # apply weight decay policies + if self.base_wd is not None: + # norm decay + if is_norm: + param_group[ + 'weight_decay'] = self.base_wd * norm_decay_mult + # depth-wise conv + elif is_dwconv: + param_group[ + 'weight_decay'] = self.base_wd * dwconv_decay_mult + # bias lr and decay + elif name == 'bias' and not is_dcn_module: + # TODO: current bias_decay_mult will have affect on DCN + param_group[ + 'weight_decay'] = self.base_wd * bias_decay_mult + params.append(param_group) + + if check_ops_exist(): + from annotator.uniformer.mmcv.ops import DeformConv2d, ModulatedDeformConv2d + is_dcn_module = isinstance(module, + (DeformConv2d, ModulatedDeformConv2d)) + else: + is_dcn_module = False + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + self.add_params( + params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) + + def __call__(self, model): + if hasattr(model, 'module'): + model = model.module + + optimizer_cfg = self.optimizer_cfg.copy() + # if no paramwise option is specified, just use the global setting + if not self.paramwise_cfg: + optimizer_cfg['params'] = model.parameters() + return build_from_cfg(optimizer_cfg, OPTIMIZERS) + + # set param-wise lr and weight decay recursively + params = [] + self.add_params(params, model) + optimizer_cfg['params'] = params + + return build_from_cfg(optimizer_cfg, OPTIMIZERS) diff --git a/annotator/uniformer/mmcv/runner/priority.py b/annotator/uniformer/mmcv/runner/priority.py new file mode 100644 index 0000000000000000000000000000000000000000..64cc4e3a05f8d5b89ab6eb32461e6e80f1d62e67 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/priority.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from enum import Enum + + +class Priority(Enum): + """Hook priority levels. + + +--------------+------------+ + | Level | Value | + +==============+============+ + | HIGHEST | 0 | + +--------------+------------+ + | VERY_HIGH | 10 | + +--------------+------------+ + | HIGH | 30 | + +--------------+------------+ + | ABOVE_NORMAL | 40 | + +--------------+------------+ + | NORMAL | 50 | + +--------------+------------+ + | BELOW_NORMAL | 60 | + +--------------+------------+ + | LOW | 70 | + +--------------+------------+ + | VERY_LOW | 90 | + +--------------+------------+ + | LOWEST | 100 | + +--------------+------------+ + """ + + HIGHEST = 0 + VERY_HIGH = 10 + HIGH = 30 + ABOVE_NORMAL = 40 + NORMAL = 50 + BELOW_NORMAL = 60 + LOW = 70 + VERY_LOW = 90 + LOWEST = 100 + + +def get_priority(priority): + """Get priority value. + + Args: + priority (int or str or :obj:`Priority`): Priority. + + Returns: + int: The priority value. + """ + if isinstance(priority, int): + if priority < 0 or priority > 100: + raise ValueError('priority must be between 0 and 100') + return priority + elif isinstance(priority, Priority): + return priority.value + elif isinstance(priority, str): + return Priority[priority.upper()].value + else: + raise TypeError('priority must be an integer or Priority enum value') diff --git a/annotator/uniformer/mmcv/runner/utils.py b/annotator/uniformer/mmcv/runner/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c5befb8e56ece50b5fecfd007b26f8a29124c0bd --- /dev/null +++ b/annotator/uniformer/mmcv/runner/utils.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import random +import sys +import time +import warnings +from getpass import getuser +from socket import gethostname + +import numpy as np +import torch + +import annotator.uniformer.mmcv as mmcv + + +def get_host_info(): + """Get hostname and username. + + Return empty string if exception raised, e.g. ``getpass.getuser()`` will + lead to error in docker container + """ + host = '' + try: + host = f'{getuser()}@{gethostname()}' + except Exception as e: + warnings.warn(f'Host or user not found: {str(e)}') + finally: + return host + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def obj_from_dict(info, parent=None, default_args=None): + """Initialize an object from dict. + + The dict must contain the key "type", which indicates the object type, it + can be either a string or type, such as "list" or ``list``. Remaining + fields are treated as the arguments for constructing the object. + + Args: + info (dict): Object types and arguments. + parent (:class:`module`): Module which may containing expected object + classes. + default_args (dict, optional): Default arguments for initializing the + object. + + Returns: + any type: Object built from the dict. + """ + assert isinstance(info, dict) and 'type' in info + assert isinstance(default_args, dict) or default_args is None + args = info.copy() + obj_type = args.pop('type') + if mmcv.is_str(obj_type): + if parent is not None: + obj_type = getattr(parent, obj_type) + else: + obj_type = sys.modules[obj_type] + elif not isinstance(obj_type, type): + raise TypeError('type must be a str or valid type, but ' + f'got {type(obj_type)}') + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + return obj_type(**args) + + +def set_random_seed(seed, deterministic=False, use_rank_shift=False): + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + rank_shift (bool): Whether to add rank number to the random seed to + have different random seed in different threads. Default: False. + """ + if use_rank_shift: + rank, _ = mmcv.runner.get_dist_info() + seed += rank + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/annotator/uniformer/mmcv/utils/__init__.py b/annotator/uniformer/mmcv/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..378a0068432a371af364de9d73785901c0f83383 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/__init__.py @@ -0,0 +1,69 @@ +# flake8: noqa +# Copyright (c) OpenMMLab. All rights reserved. +from .config import Config, ConfigDict, DictAction +from .misc import (check_prerequisites, concat_list, deprecated_api_warning, + has_method, import_modules_from_strings, is_list_of, + is_method_overridden, is_seq_of, is_str, is_tuple_of, + iter_cast, list_cast, requires_executable, requires_package, + slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, + to_ntuple, tuple_cast) +from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, + scandir, symlink) +from .progressbar import (ProgressBar, track_iter_progress, + track_parallel_progress, track_progress) +from .testing import (assert_attrs_equal, assert_dict_contains_subset, + assert_dict_has_keys, assert_is_norm_layer, + assert_keys_equal, assert_params_all_zeros, + check_python_script) +from .timer import Timer, TimerError, check_time +from .version_utils import digit_version, get_git_hash + +try: + import torch +except ImportError: + __all__ = [ + 'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast', + 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of', + 'slice_list', 'concat_list', 'check_prerequisites', 'requires_package', + 'requires_executable', 'is_filepath', 'fopen', 'check_file_exist', + 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', + 'track_progress', 'track_iter_progress', 'track_parallel_progress', + 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', + 'digit_version', 'get_git_hash', 'import_modules_from_strings', + 'assert_dict_contains_subset', 'assert_attrs_equal', + 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', + 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', + 'is_method_overridden', 'has_method' + ] +else: + from .env import collect_env + from .logging import get_logger, print_log + from .parrots_jit import jit, skip_no_elena + from .parrots_wrapper import ( + TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader, + PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, + _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, + _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home) + from .registry import Registry, build_from_cfg + from .trace import is_jit_tracing + __all__ = [ + 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', + 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', + 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list', + 'check_prerequisites', 'requires_package', 'requires_executable', + 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', + 'symlink', 'scandir', 'ProgressBar', 'track_progress', + 'track_iter_progress', 'track_parallel_progress', 'Registry', + 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm', + '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm', + '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd', + 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension', + 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION', + 'deprecated_api_warning', 'digit_version', 'get_git_hash', + 'import_modules_from_strings', 'jit', 'skip_no_elena', + 'assert_dict_contains_subset', 'assert_attrs_equal', + 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', + 'assert_params_all_zeros', 'check_python_script', + 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', + '_get_cuda_home', 'has_method' + ] diff --git a/annotator/uniformer/mmcv/utils/config.py b/annotator/uniformer/mmcv/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..17149353aefac6d737c67bb2f35a3a6cd2147b0a --- /dev/null +++ b/annotator/uniformer/mmcv/utils/config.py @@ -0,0 +1,688 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import ast +import copy +import os +import os.path as osp +import platform +import shutil +import sys +import tempfile +import uuid +import warnings +from argparse import Action, ArgumentParser +from collections import abc +from importlib import import_module + +from addict import Dict +from yapf.yapflib.yapf_api import FormatCode + +from .misc import import_modules_from_strings +from .path import check_file_exist + +if platform.system() == 'Windows': + import regex as re +else: + import re + +BASE_KEY = '_base_' +DELETE_KEY = '_delete_' +DEPRECATION_KEY = '_deprecation_' +RESERVED_KEYS = ['filename', 'text', 'pretty_text'] + + +class ConfigDict(Dict): + + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super(ConfigDict, self).__getattr__(name) + except KeyError: + ex = AttributeError(f"'{self.__class__.__name__}' object has no " + f"attribute '{name}'") + except Exception as e: + ex = e + else: + return value + raise ex + + +def add_args(parser, cfg, prefix=''): + for k, v in cfg.items(): + if isinstance(v, str): + parser.add_argument('--' + prefix + k) + elif isinstance(v, int): + parser.add_argument('--' + prefix + k, type=int) + elif isinstance(v, float): + parser.add_argument('--' + prefix + k, type=float) + elif isinstance(v, bool): + parser.add_argument('--' + prefix + k, action='store_true') + elif isinstance(v, dict): + add_args(parser, v, prefix + k + '.') + elif isinstance(v, abc.Iterable): + parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') + else: + print(f'cannot parse key {prefix + k} of type {type(v)}') + return parser + + +class Config: + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. + + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/mmcv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ + + @staticmethod + def _validate_py_syntax(filename): + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError('There are syntax errors in config ' + f'file {filename}: {e}') + + @staticmethod + def _substitute_predefined_vars(filename, temp_config_name): + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname) + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + for key, value in support_templates.items(): + regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' + value = value.replace('\\', '/') + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: + tmp_config_file.write(config_file) + + @staticmethod + def _pre_substitute_base_vars(filename, temp_config_name): + """Substitute base variable placehoders to string, so that parsing + would work.""" + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + base_var_dict = {} + regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' + base_var_dict[randstr] = base_var + regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}' + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _substitute_base_vars(cfg, base_var_dict, base_cfg): + """Substitute variable strings to their actual values.""" + cfg = copy.deepcopy(cfg) + + if isinstance(cfg, dict): + for k, v in cfg.items(): + if isinstance(v, str) and v in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[v].split('.'): + new_v = new_v[new_k] + cfg[k] = new_v + elif isinstance(v, (list, tuple, dict)): + cfg[k] = Config._substitute_base_vars( + v, base_var_dict, base_cfg) + elif isinstance(cfg, tuple): + cfg = tuple( + Config._substitute_base_vars(c, base_var_dict, base_cfg) + for c in cfg) + elif isinstance(cfg, list): + cfg = [ + Config._substitute_base_vars(c, base_var_dict, base_cfg) + for c in cfg + ] + elif isinstance(cfg, str) and cfg in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[cfg].split('.'): + new_v = new_v[new_k] + cfg = new_v + + return cfg + + @staticmethod + def _file2dict(filename, use_predefined_variables=True): + filename = osp.abspath(osp.expanduser(filename)) + check_file_exist(filename) + fileExtname = osp.splitext(filename)[1] + if fileExtname not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile( + dir=temp_config_dir, suffix=fileExtname) + if platform.system() == 'Windows': + temp_config_file.close() + temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, + temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars( + temp_config_file.name, temp_config_file.name) + + if filename.endswith('.py'): + temp_module_name = osp.splitext(temp_config_name)[0] + sys.path.insert(0, temp_config_dir) + Config._validate_py_syntax(filename) + mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = { + name: value + for name, value in mod.__dict__.items() + if not name.startswith('__') + } + # delete imported module + del sys.modules[temp_module_name] + elif filename.endswith(('.yml', '.yaml', '.json')): + import annotator.uniformer.mmcv as mmcv + cfg_dict = mmcv.load(temp_config_file.name) + # close temp file + temp_config_file.close() + + # check deprecation information + if DEPRECATION_KEY in cfg_dict: + deprecation_info = cfg_dict.pop(DEPRECATION_KEY) + warning_msg = f'The config file {filename} will be deprecated ' \ + 'in the future.' + if 'expected' in deprecation_info: + warning_msg += f' Please use {deprecation_info["expected"]} ' \ + 'instead.' + if 'reference' in deprecation_info: + warning_msg += ' More information can be found at ' \ + f'{deprecation_info["reference"]}' + warnings.warn(warning_msg) + + cfg_text = filename + '\n' + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + if BASE_KEY in cfg_dict: + cfg_dir = osp.dirname(filename) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = base_filename if isinstance( + base_filename, list) else [base_filename] + + cfg_dict_list = list() + cfg_text_list = list() + for f in base_filename: + _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + cfg_text_list.append(_cfg_text) + + base_cfg_dict = dict() + for c in cfg_dict_list: + duplicate_keys = base_cfg_dict.keys() & c.keys() + if len(duplicate_keys) > 0: + raise KeyError('Duplicate key is not allowed among bases. ' + f'Duplicate keys: {duplicate_keys}') + base_cfg_dict.update(c) + + # Substitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, + base_cfg_dict) + + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = base_cfg_dict + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = '\n'.join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def _merge_a_into_b(a, b, allow_list_keys=False): + """merge dict ``a`` into dict ``b`` (non-inplace). + + Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid + in-place modifications. + + Args: + a (dict): The source dict to be merged into ``b``. + b (dict): The origin dict to be fetch keys from ``a``. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in source ``a`` and will replace the element of the + corresponding index in b if b is a list. Default: False. + + Returns: + dict: The modified dict of ``b`` using ``a``. + + Examples: + # Normally merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # Delete b first and merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # b is a list + >>> Config._merge_a_into_b( + ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) + [{'a': 2}, {'b': 2}] + """ + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f'Index {k} exceeds the length of list {b}') + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + elif isinstance(v, + dict) and k in b and not v.pop(DELETE_KEY, False): + allowed_types = (dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f'{k}={v} in child config cannot inherit from base ' + f'because {k} is a dict in the child config but is of ' + f'type {type(b[k])} in base config. You may set ' + f'`{DELETE_KEY}=True` to ignore the base config') + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + else: + b[k] = v + return b + + @staticmethod + def fromfile(filename, + use_predefined_variables=True, + import_custom_modules=True): + cfg_dict, cfg_text = Config._file2dict(filename, + use_predefined_variables) + if import_custom_modules and cfg_dict.get('custom_imports', None): + import_modules_from_strings(**cfg_dict['custom_imports']) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def fromstring(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + obj:`Config`: Config obj. + """ + if file_format not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + if file_format != '.py' and 'dict(' in cfg_str: + # check if users specify a wrong suffix for python + warnings.warn( + 'Please check "file_format", the file format may be .py') + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix=file_format, + delete=False) as temp_file: + temp_file.write(cfg_str) + # on windows, previous implementation cause error + # see PR 1077 for details + cfg = Config.fromfile(temp_file.name) + os.remove(temp_file.name) + return cfg + + @staticmethod + def auto_argparser(description=None): + """Generate argparser from config file automatically (experimental)""" + partial_parser = ArgumentParser(description=description) + partial_parser.add_argument('config', help='config file path') + cfg_file = partial_parser.parse_known_args()[0].config + cfg = Config.fromfile(cfg_file) + parser = ArgumentParser(description=description) + parser.add_argument('config', help='config file path') + add_args(parser, cfg) + return parser, cfg + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError('cfg_dict must be a dict, but ' + f'got {type(cfg_dict)}') + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f'{key} is reserved for config file') + + super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) + super(Config, self).__setattr__('_filename', filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, 'r') as f: + text = f.read() + else: + text = '' + super(Config, self).__setattr__('_text', text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split('\n') + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = '[\n' + v_str += '\n'.join( + f'dict({_indent(_format_dict(v_), indent)}),' + for v_ in v).rstrip(',') + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + ']' + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= \ + (not str(key_name).isidentifier()) + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = '' + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += '{' + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = '' if outest_level or is_last else ',' + if isinstance(v, dict): + v_str = '\n' + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: dict({v_str}' + else: + attr_str = f'{str(k)}=dict({v_str}' + attr_str = _indent(attr_str, indent) + ')' + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += '\n'.join(s) + if use_mapping: + r += '}' + return r + + cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + + return text + + def __repr__(self): + return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__(self): + return (self._cfg_dict, self._filename, self._text) + + def __setstate__(self, state): + _cfg_dict, _filename, _text = state + super(Config, self).__setattr__('_cfg_dict', _cfg_dict) + super(Config, self).__setattr__('_filename', _filename) + super(Config, self).__setattr__('_text', _text) + + def dump(self, file=None): + cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() + if self.filename.endswith('.py'): + if file is None: + return self.pretty_text + else: + with open(file, 'w', encoding='utf-8') as f: + f.write(self.pretty_text) + else: + import annotator.uniformer.mmcv as mmcv + if file is None: + file_format = self.filename.split('.')[-1] + return mmcv.dump(cfg_dict, file_format=file_format) + else: + mmcv.dump(cfg_dict, file) + + def merge_from_dict(self, options, allow_list_keys=True): + """Merge list into cfg_dict. + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'model.backbone.depth': 50, + ... 'model.backbone.with_cp':True} + >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... model=dict(backbone=dict(depth=50, with_cp=True))) + + # Merge list element + >>> cfg = Config(dict(pipeline=[ + ... dict(type='LoadImage'), dict(type='LoadAnnotations')])) + >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) + + Args: + options (dict): dict of configs to merge from. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in ``options`` and will replace the element of the + corresponding index in the config if the config is a list. + Default: True. + """ + option_cfg_dict = {} + for full_key, v in options.items(): + d = option_cfg_dict + key_list = full_key.split('.') + for subkey in key_list[:-1]: + d.setdefault(subkey, ConfigDict()) + d = d[subkey] + subkey = key_list[-1] + d[subkey] = v + + cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + super(Config, self).__setattr__( + '_cfg_dict', + Config._merge_a_into_b( + option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) + + +class DictAction(Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. List options can + be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit + brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build + list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' + """ + + @staticmethod + def _parse_int_float_bool(val): + try: + return int(val) + except ValueError: + pass + try: + return float(val) + except ValueError: + pass + if val.lower() in ['true', 'false']: + return True if val.lower() == 'true' else False + return val + + @staticmethod + def _parse_iterable(val): + """Parse iterable values in the string. + + All elements inside '()' or '[]' are treated as iterable values. + + Args: + val (str): Value string. + + Returns: + list | tuple: The expanded list or tuple from the string. + + Examples: + >>> DictAction._parse_iterable('1,2,3') + [1, 2, 3] + >>> DictAction._parse_iterable('[a, b, c]') + ['a', 'b', 'c'] + >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') + [(1, 2, 3), ['a', 'b'], 'c'] + """ + + def find_next_comma(string): + """Find the position of next comma in the string. + + If no ',' is found in the string, return the string length. All + chars inside '()' and '[]' are treated as one element and thus ',' + inside these brackets are ignored. + """ + assert (string.count('(') == string.count(')')) and ( + string.count('[') == string.count(']')), \ + f'Imbalanced brackets exist in {string}' + end = len(string) + for idx, char in enumerate(string): + pre = string[:idx] + # The string before this ',' is balanced + if ((char == ',') and (pre.count('(') == pre.count(')')) + and (pre.count('[') == pre.count(']'))): + end = idx + break + return end + + # Strip ' and " characters and replace whitespace. + val = val.strip('\'\"').replace(' ', '') + is_tuple = False + if val.startswith('(') and val.endswith(')'): + is_tuple = True + val = val[1:-1] + elif val.startswith('[') and val.endswith(']'): + val = val[1:-1] + elif ',' not in val: + # val is a single value + return DictAction._parse_int_float_bool(val) + + values = [] + while len(val) > 0: + comma_idx = find_next_comma(val) + element = DictAction._parse_iterable(val[:comma_idx]) + values.append(element) + val = val[comma_idx + 1:] + if is_tuple: + values = tuple(values) + return values + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for kv in values: + key, val = kv.split('=', maxsplit=1) + options[key] = self._parse_iterable(val) + setattr(namespace, self.dest, options) diff --git a/annotator/uniformer/mmcv/utils/env.py b/annotator/uniformer/mmcv/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f0d92529e193e6d8339419bcd9bed7901a7769 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/env.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file holding some environment constant for sharing by other files.""" + +import os.path as osp +import subprocess +import sys +from collections import defaultdict + +import cv2 +import torch + +import annotator.uniformer.mmcv as mmcv +from .parrots_wrapper import get_build_config + + +def collect_env(): + """Collect the information of the running environments. + + Returns: + dict: The environment information. The following fields are contained. + + - sys.platform: The variable of ``sys.platform``. + - Python: Python version. + - CUDA available: Bool, indicating if CUDA is available. + - GPU devices: Device type of each GPU. + - CUDA_HOME (optional): The env var ``CUDA_HOME``. + - NVCC (optional): NVCC version. + - GCC: GCC version, "n/a" if GCC is not installed. + - PyTorch: PyTorch version. + - PyTorch compiling details: The output of \ + ``torch.__config__.show()``. + - TorchVision (optional): TorchVision version. + - OpenCV: OpenCV version. + - MMCV: MMCV version. + - MMCV Compiler: The GCC version for compiling MMCV ops. + - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops. + """ + env_info = {} + env_info['sys.platform'] = sys.platform + env_info['Python'] = sys.version.replace('\n', '') + + cuda_available = torch.cuda.is_available() + env_info['CUDA available'] = cuda_available + + if cuda_available: + devices = defaultdict(list) + for k in range(torch.cuda.device_count()): + devices[torch.cuda.get_device_name(k)].append(str(k)) + for name, device_ids in devices.items(): + env_info['GPU ' + ','.join(device_ids)] = name + + from annotator.uniformer.mmcv.utils.parrots_wrapper import _get_cuda_home + CUDA_HOME = _get_cuda_home() + env_info['CUDA_HOME'] = CUDA_HOME + + if CUDA_HOME is not None and osp.isdir(CUDA_HOME): + try: + nvcc = osp.join(CUDA_HOME, 'bin/nvcc') + nvcc = subprocess.check_output( + f'"{nvcc}" -V | tail -n1', shell=True) + nvcc = nvcc.decode('utf-8').strip() + except subprocess.SubprocessError: + nvcc = 'Not Available' + env_info['NVCC'] = nvcc + + try: + gcc = subprocess.check_output('gcc --version | head -n1', shell=True) + gcc = gcc.decode('utf-8').strip() + env_info['GCC'] = gcc + except subprocess.CalledProcessError: # gcc is unavailable + env_info['GCC'] = 'n/a' + + env_info['PyTorch'] = torch.__version__ + env_info['PyTorch compiling details'] = get_build_config() + + try: + import torchvision + env_info['TorchVision'] = torchvision.__version__ + except ModuleNotFoundError: + pass + + env_info['OpenCV'] = cv2.__version__ + + env_info['MMCV'] = mmcv.__version__ + + try: + from annotator.uniformer.mmcv.ops import get_compiler_version, get_compiling_cuda_version + except ModuleNotFoundError: + env_info['MMCV Compiler'] = 'n/a' + env_info['MMCV CUDA Compiler'] = 'n/a' + else: + env_info['MMCV Compiler'] = get_compiler_version() + env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version() + + return env_info diff --git a/annotator/uniformer/mmcv/utils/ext_loader.py b/annotator/uniformer/mmcv/utils/ext_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..08132d2c1b9a1c28880e4bab4d4fa1ba39d9d083 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/ext_loader.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +import os +import pkgutil +import warnings +from collections import namedtuple + +import torch + +if torch.__version__ != 'parrots': + + def load_ext(name, funcs): + ext = importlib.import_module('mmcv.' + name) + for fun in funcs: + assert hasattr(ext, fun), f'{fun} miss in module {name}' + return ext +else: + from parrots import extension + from parrots.base import ParrotsException + + has_return_value_ops = [ + 'nms', + 'softnms', + 'nms_match', + 'nms_rotated', + 'top_pool_forward', + 'top_pool_backward', + 'bottom_pool_forward', + 'bottom_pool_backward', + 'left_pool_forward', + 'left_pool_backward', + 'right_pool_forward', + 'right_pool_backward', + 'fused_bias_leakyrelu', + 'upfirdn2d', + 'ms_deform_attn_forward', + 'pixel_group', + 'contour_expand', + ] + + def get_fake_func(name, e): + + def fake_func(*args, **kwargs): + warnings.warn(f'{name} is not supported in parrots now') + raise e + + return fake_func + + def load_ext(name, funcs): + ExtModule = namedtuple('ExtModule', funcs) + ext_list = [] + lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + for fun in funcs: + try: + ext_fun = extension.load(fun, name, lib_dir=lib_root) + except ParrotsException as e: + if 'No element registered' not in e.message: + warnings.warn(e.message) + ext_fun = get_fake_func(fun, e) + ext_list.append(ext_fun) + else: + if fun in has_return_value_ops: + ext_list.append(ext_fun.op) + else: + ext_list.append(ext_fun.op_) + return ExtModule(*ext_list) + + +def check_ops_exist(): + ext_loader = pkgutil.find_loader('mmcv._ext') + return ext_loader is not None diff --git a/annotator/uniformer/mmcv/utils/logging.py b/annotator/uniformer/mmcv/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa0e04bb9b3ab2a4bfbc4def50404ccbac2c6e6 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/logging.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.distributed as dist + +logger_initialized = {} + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + # handle duplicate logs to the console + # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) + # to the root logger. As logger.propagate is True by default, this root + # level handler causes logging messages from rank>0 processes to + # unexpectedly show up on the console, creating much unwanted clutter. + # To fix this issue, we set the root logger's StreamHandler, if any, to log + # at the ERROR level. + for handler in logger.root.handlers: + if type(handler) is logging.StreamHandler: + handler.setLevel(logging.ERROR) + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. + Some special loggers are: + - "silent": no message will be printed. + - other str: the logger obtained with `get_root_logger(logger)`. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger == 'silent': + pass + elif isinstance(logger, str): + _logger = get_logger(logger) + _logger.log(level, msg) + else: + raise TypeError( + 'logger should be either a logging.Logger object, str, ' + f'"silent" or None, but got {type(logger)}') diff --git a/annotator/uniformer/mmcv/utils/misc.py b/annotator/uniformer/mmcv/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..2c58d0d7fee9fe3d4519270ad8c1e998d0d8a18c --- /dev/null +++ b/annotator/uniformer/mmcv/utils/misc.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections.abc +import functools +import itertools +import subprocess +import warnings +from collections import abc +from importlib import import_module +from inspect import getfullargspec +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def is_str(x): + """Whether the input is an string instance. + + Note: This method is deprecated since python 2 is no longer supported. + """ + return isinstance(x, str) + + +def import_modules_from_strings(imports, allow_failed_imports=False): + """Import modules from the given list of strings. + + Args: + imports (list | str | None): The given module names to be imported. + allow_failed_imports (bool): If True, the failed imports will return + None. Otherwise, an ImportError is raise. Default: False. + + Returns: + list[module] | module | None: The imported modules. + + Examples: + >>> osp, sys = import_modules_from_strings( + ... ['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError( + f'custom_imports must be a list but got type {type(imports)}') + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError( + f'{imp} is of type {type(imp)} and cannot be imported.') + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + warnings.warn(f'{imp} failed to import and is ignored.', + UserWarning) + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported + + +def iter_cast(inputs, dst_type, return_type=None): + """Cast elements of an iterable object into some type. + + Args: + inputs (Iterable): The input object. + dst_type (type): Destination type. + return_type (type, optional): If specified, the output object will be + converted to this type, otherwise an iterator. + + Returns: + iterator or specified type: The converted object. + """ + if not isinstance(inputs, abc.Iterable): + raise TypeError('inputs must be an iterable object') + if not isinstance(dst_type, type): + raise TypeError('"dst_type" must be a valid type') + + out_iterable = map(dst_type, inputs) + + if return_type is None: + return out_iterable + else: + return return_type(out_iterable) + + +def list_cast(inputs, dst_type): + """Cast elements of an iterable object into a list of some type. + + A partial method of :func:`iter_cast`. + """ + return iter_cast(inputs, dst_type, return_type=list) + + +def tuple_cast(inputs, dst_type): + """Cast elements of an iterable object into a tuple of some type. + + A partial method of :func:`iter_cast`. + """ + return iter_cast(inputs, dst_type, return_type=tuple) + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def is_list_of(seq, expected_type): + """Check whether it is a list of some type. + + A partial method of :func:`is_seq_of`. + """ + return is_seq_of(seq, expected_type, seq_type=list) + + +def is_tuple_of(seq, expected_type): + """Check whether it is a tuple of some type. + + A partial method of :func:`is_seq_of`. + """ + return is_seq_of(seq, expected_type, seq_type=tuple) + + +def slice_list(in_list, lens): + """Slice a list into several sub lists by a list of given length. + + Args: + in_list (list): The list to be sliced. + lens(int or list): The expected length of each out list. + + Returns: + list: A list of sliced list. + """ + if isinstance(lens, int): + assert len(in_list) % lens == 0 + lens = [lens] * int(len(in_list) / lens) + if not isinstance(lens, list): + raise TypeError('"indices" must be an integer or a list of integers') + elif sum(lens) != len(in_list): + raise ValueError('sum of lens and list length does not ' + f'match: {sum(lens)} != {len(in_list)}') + out_list = [] + idx = 0 + for i in range(len(lens)): + out_list.append(in_list[idx:idx + lens[i]]) + idx += lens[i] + return out_list + + +def concat_list(in_list): + """Concatenate a list of list into a single list. + + Args: + in_list (list): The list of list to be merged. + + Returns: + list: The concatenated flat list. + """ + return list(itertools.chain(*in_list)) + + +def check_prerequisites( + prerequisites, + checker, + msg_tmpl='Prerequisites "{}" are required in method "{}" but not ' + 'found, please install them first.'): # yapf: disable + """A decorator factory to check if prerequisites are satisfied. + + Args: + prerequisites (str of list[str]): Prerequisites to be checked. + checker (callable): The checker method that returns True if a + prerequisite is meet, False otherwise. + msg_tmpl (str): The message template with two variables. + + Returns: + decorator: A specific decorator. + """ + + def wrap(func): + + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + requirements = [prerequisites] if isinstance( + prerequisites, str) else prerequisites + missing = [] + for item in requirements: + if not checker(item): + missing.append(item) + if missing: + print(msg_tmpl.format(', '.join(missing), func.__name__)) + raise RuntimeError('Prerequisites not meet.') + else: + return func(*args, **kwargs) + + return wrapped_func + + return wrap + + +def _check_py_package(package): + try: + import_module(package) + except ImportError: + return False + else: + return True + + +def _check_executable(cmd): + if subprocess.call(f'which {cmd}', shell=True) != 0: + return False + else: + return True + + +def requires_package(prerequisites): + """A decorator to check if some python packages are installed. + + Example: + >>> @requires_package('numpy') + >>> func(arg1, args): + >>> return numpy.zeros(1) + array([0.]) + >>> @requires_package(['numpy', 'non_package']) + >>> func(arg1, args): + >>> return numpy.zeros(1) + ImportError + """ + return check_prerequisites(prerequisites, checker=_check_py_package) + + +def requires_executable(prerequisites): + """A decorator to check if some executable files are installed. + + Example: + >>> @requires_executable('ffmpeg') + >>> func(arg1, args): + >>> print(1) + 1 + """ + return check_prerequisites(prerequisites, checker=_check_executable) + + +def deprecated_api_warning(name_dict, cls_name=None): + """A decorator to check if some arguments are deprecate and try to replace + deprecate src_arg_name to dst_arg_name. + + Args: + name_dict(dict): + key (str): Deprecate argument names. + val (str): Expected argument names. + + Returns: + func: New function. + """ + + def api_warning_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get name of the function + func_name = old_func.__name__ + if cls_name is not None: + func_name = f'{cls_name}.{func_name}' + if args: + arg_names = args_info.args[:len(args)] + for src_arg_name, dst_arg_name in name_dict.items(): + if src_arg_name in arg_names: + warnings.warn( + f'"{src_arg_name}" is deprecated in ' + f'`{func_name}`, please use "{dst_arg_name}" ' + 'instead') + arg_names[arg_names.index(src_arg_name)] = dst_arg_name + if kwargs: + for src_arg_name, dst_arg_name in name_dict.items(): + if src_arg_name in kwargs: + + assert dst_arg_name not in kwargs, ( + f'The expected behavior is to replace ' + f'the deprecated key `{src_arg_name}` to ' + f'new key `{dst_arg_name}`, but got them ' + f'in the arguments at the same time, which ' + f'is confusing. `{src_arg_name} will be ' + f'deprecated in the future, please ' + f'use `{dst_arg_name}` instead.') + + warnings.warn( + f'"{src_arg_name}" is deprecated in ' + f'`{func_name}`, please use "{dst_arg_name}" ' + 'instead') + kwargs[dst_arg_name] = kwargs.pop(src_arg_name) + + # apply converted arguments to the decorated method + output = old_func(*args, **kwargs) + return output + + return new_func + + return api_warning_wrapper + + +def is_method_overridden(method, base_class, derived_class): + """Check if a method of base class is overridden in derived class. + + Args: + method (str): the method name to check. + base_class (type): the class of the base class. + derived_class (type | Any): the class or instance of the derived class. + """ + assert isinstance(base_class, type), \ + "base_class doesn't accept instance, Please pass class instead." + + if not isinstance(derived_class, type): + derived_class = derived_class.__class__ + + base_method = getattr(base_class, method) + derived_method = getattr(derived_class, method) + return derived_method != base_method + + +def has_method(obj: object, method: str) -> bool: + """Check whether the object has a method. + + Args: + method (str): The method name to check. + obj (object): The object to check. + + Returns: + bool: True if the object has the method else False. + """ + return hasattr(obj, method) and callable(getattr(obj, method)) diff --git a/annotator/uniformer/mmcv/utils/parrots_jit.py b/annotator/uniformer/mmcv/utils/parrots_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..61873f6dbb9b10ed972c90aa8faa321e3cb3249e --- /dev/null +++ b/annotator/uniformer/mmcv/utils/parrots_jit.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os + +from .parrots_wrapper import TORCH_VERSION + +parrots_jit_option = os.getenv('PARROTS_JIT_OPTION') + +if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON': + from parrots.jit import pat as jit +else: + + def jit(func=None, + check_input=None, + full_shape=True, + derivate=False, + coderize=False, + optimize=False): + + def wrapper(func): + + def wrapper_inner(*args, **kargs): + return func(*args, **kargs) + + return wrapper_inner + + if func is None: + return wrapper + else: + return func + + +if TORCH_VERSION == 'parrots': + from parrots.utils.tester import skip_no_elena +else: + + def skip_no_elena(func): + + def wrapper(*args, **kargs): + return func(*args, **kargs) + + return wrapper diff --git a/annotator/uniformer/mmcv/utils/parrots_wrapper.py b/annotator/uniformer/mmcv/utils/parrots_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..93c97640d4b9ed088ca82cfe03e6efebfcfa9dbf --- /dev/null +++ b/annotator/uniformer/mmcv/utils/parrots_wrapper.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial + +import torch + +TORCH_VERSION = torch.__version__ + + +def is_rocm_pytorch() -> bool: + is_rocm = False + if TORCH_VERSION != 'parrots': + try: + from torch.utils.cpp_extension import ROCM_HOME + is_rocm = True if ((torch.version.hip is not None) and + (ROCM_HOME is not None)) else False + except ImportError: + pass + return is_rocm + + +def _get_cuda_home(): + if TORCH_VERSION == 'parrots': + from parrots.utils.build_extension import CUDA_HOME + else: + if is_rocm_pytorch(): + from torch.utils.cpp_extension import ROCM_HOME + CUDA_HOME = ROCM_HOME + else: + from torch.utils.cpp_extension import CUDA_HOME + return CUDA_HOME + + +def get_build_config(): + if TORCH_VERSION == 'parrots': + from parrots.config import get_build_info + return get_build_info() + else: + return torch.__config__.show() + + +def _get_conv(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin + else: + from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin + return _ConvNd, _ConvTransposeMixin + + +def _get_dataloader(): + if TORCH_VERSION == 'parrots': + from torch.utils.data import DataLoader, PoolDataLoader + else: + from torch.utils.data import DataLoader + PoolDataLoader = DataLoader + return DataLoader, PoolDataLoader + + +def _get_extension(): + if TORCH_VERSION == 'parrots': + from parrots.utils.build_extension import BuildExtension, Extension + CppExtension = partial(Extension, cuda=False) + CUDAExtension = partial(Extension, cuda=True) + else: + from torch.utils.cpp_extension import (BuildExtension, CppExtension, + CUDAExtension) + return BuildExtension, CppExtension, CUDAExtension + + +def _get_pool(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, + _MaxPoolNd) + else: + from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, + _MaxPoolNd) + return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd + + +def _get_norm(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm + SyncBatchNorm_ = torch.nn.SyncBatchNorm2d + else: + from torch.nn.modules.instancenorm import _InstanceNorm + from torch.nn.modules.batchnorm import _BatchNorm + SyncBatchNorm_ = torch.nn.SyncBatchNorm + return _BatchNorm, _InstanceNorm, SyncBatchNorm_ + + +_ConvNd, _ConvTransposeMixin = _get_conv() +DataLoader, PoolDataLoader = _get_dataloader() +BuildExtension, CppExtension, CUDAExtension = _get_extension() +_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() +_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() + + +class SyncBatchNorm(SyncBatchNorm_): + + def _check_input_dim(self, input): + if TORCH_VERSION == 'parrots': + if input.dim() < 2: + raise ValueError( + f'expected at least 2D input (got {input.dim()}D input)') + else: + super()._check_input_dim(input) diff --git a/annotator/uniformer/mmcv/utils/path.py b/annotator/uniformer/mmcv/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..7dab4b3041413b1432b0f434b8b14783097d33c6 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/path.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from pathlib import Path + +from .misc import is_str + + +def is_filepath(x): + return is_str(x) or isinstance(x, Path) + + +def fopen(filepath, *args, **kwargs): + if is_str(filepath): + return open(filepath, *args, **kwargs) + elif isinstance(filepath, Path): + return filepath.open(*args, **kwargs) + raise ValueError('`filepath` should be a string or a Path') + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == '': + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src, dst, overwrite=True, **kwargs): + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): + """Scan a directory to find the interested files. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. Default: True. + + Returns: + A generator for all the interested files with relative paths. + """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = suffix.lower() if isinstance(suffix, str) else tuple( + item.lower() for item in suffix) + + root = dir_path + + def _scandir(dir_path, suffix, recursive, case_sensitive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, + case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +def find_vcs_root(path, markers=('.git', )): + """Finds the root directory (including itself) of specified markers. + + Args: + path (str): Path of directory or file. + markers (list[str], optional): List of file or directory names. + + Returns: + The directory contained one of the markers or None if not found. + """ + if osp.isfile(path): + path = osp.dirname(path) + + prev, cur = None, osp.abspath(osp.expanduser(path)) + while cur != prev: + if any(osp.exists(osp.join(cur, marker)) for marker in markers): + return cur + prev, cur = cur, osp.split(cur)[0] + return None diff --git a/annotator/uniformer/mmcv/utils/progressbar.py b/annotator/uniformer/mmcv/utils/progressbar.py new file mode 100644 index 0000000000000000000000000000000000000000..0062f670dd94fa9da559ab26ef85517dcf5211c7 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/progressbar.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +from collections.abc import Iterable +from multiprocessing import Pool +from shutil import get_terminal_size + +from .timer import Timer + + +class ProgressBar: + """A progress bar which can print the progress.""" + + def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout): + self.task_num = task_num + self.bar_width = bar_width + self.completed = 0 + self.file = file + if start: + self.start() + + @property + def terminal_width(self): + width, _ = get_terminal_size() + return width + + def start(self): + if self.task_num > 0: + self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, ' + 'elapsed: 0s, ETA:') + else: + self.file.write('completed: 0, elapsed: 0s') + self.file.flush() + self.timer = Timer() + + def update(self, num_tasks=1): + assert num_tasks > 0 + self.completed += num_tasks + elapsed = self.timer.since_start() + if elapsed > 0: + fps = self.completed / elapsed + else: + fps = float('inf') + if self.task_num > 0: + percentage = self.completed / float(self.task_num) + eta = int(elapsed * (1 - percentage) / percentage + 0.5) + msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \ + f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \ + f'ETA: {eta:5}s' + + bar_width = min(self.bar_width, + int(self.terminal_width - len(msg)) + 2, + int(self.terminal_width * 0.6)) + bar_width = max(2, bar_width) + mark_width = int(bar_width * percentage) + bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width) + self.file.write(msg.format(bar_chars)) + else: + self.file.write( + f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' + f' {fps:.1f} tasks/s') + self.file.flush() + + +def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs): + """Track the progress of tasks execution with a progress bar. + + Tasks are done with a simple for-loop. + + Args: + func (callable): The function to be applied to each task. + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + bar_width (int): Width of progress bar. + + Returns: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + prog_bar = ProgressBar(task_num, bar_width, file=file) + results = [] + for task in tasks: + results.append(func(task, **kwargs)) + prog_bar.update() + prog_bar.file.write('\n') + return results + + +def init_pool(process_num, initializer=None, initargs=None): + if initializer is None: + return Pool(process_num) + elif initargs is None: + return Pool(process_num, initializer) + else: + if not isinstance(initargs, tuple): + raise TypeError('"initargs" must be a tuple') + return Pool(process_num, initializer, initargs) + + +def track_parallel_progress(func, + tasks, + nproc, + initializer=None, + initargs=None, + bar_width=50, + chunksize=1, + skip_first=False, + keep_order=True, + file=sys.stdout): + """Track the progress of parallel task execution with a progress bar. + + The built-in :mod:`multiprocessing` module is used for process pools and + tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. + + Args: + func (callable): The function to be applied to each task. + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + nproc (int): Process (worker) number. + initializer (None or callable): Refer to :class:`multiprocessing.Pool` + for details. + initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for + details. + chunksize (int): Refer to :class:`multiprocessing.Pool` for details. + bar_width (int): Width of progress bar. + skip_first (bool): Whether to skip the first sample for each worker + when estimating fps, since the initialization step may takes + longer. + keep_order (bool): If True, :func:`Pool.imap` is used, otherwise + :func:`Pool.imap_unordered` is used. + + Returns: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + pool = init_pool(nproc, initializer, initargs) + start = not skip_first + task_num -= nproc * chunksize * int(skip_first) + prog_bar = ProgressBar(task_num, bar_width, start, file=file) + results = [] + if keep_order: + gen = pool.imap(func, tasks, chunksize) + else: + gen = pool.imap_unordered(func, tasks, chunksize) + for result in gen: + results.append(result) + if skip_first: + if len(results) < nproc * chunksize: + continue + elif len(results) == nproc * chunksize: + prog_bar.start() + continue + prog_bar.update() + prog_bar.file.write('\n') + pool.close() + pool.join() + return results + + +def track_iter_progress(tasks, bar_width=50, file=sys.stdout): + """Track the progress of tasks iteration or enumeration with a progress + bar. + + Tasks are yielded with a simple for-loop. + + Args: + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + bar_width (int): Width of progress bar. + + Yields: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + prog_bar = ProgressBar(task_num, bar_width, file=file) + for task in tasks: + yield task + prog_bar.update() + prog_bar.file.write('\n') diff --git a/annotator/uniformer/mmcv/utils/registry.py b/annotator/uniformer/mmcv/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9df39bc9f3d8d568361e7250ab35468f2b74e0 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/registry.py @@ -0,0 +1,315 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import warnings +from functools import partial + +from .misc import is_seq_of + + +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from config dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + if default_args is None or 'type' not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f'but got {cfg}\n{default_args}') + if not isinstance(registry, Registry): + raise TypeError('registry must be an mmcv.Registry object, ' + f'but got {type(registry)}') + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError('default_args must be a dict or None, ' + f'but got {type(default_args)}') + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop('type') + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name} registry') + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f'{obj_cls.__name__}: {e}') + + +class Registry: + """A registry to map strings to classes. + + Registered object could be built from registry. + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + + Please refer to + https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = self.__class__.__name__ + \ + f'(name={self._name}, ' \ + f'items={self._module_dict})' + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. + + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() trace where this function is called, the index-2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split('.') + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find('.') + if split_index != -1: + return key[:split_index], key[split_index + 1:] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. + + Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert registry.scope not in self.children, \ + f'scope {registry.scope} exists in {self.name} registry' + self.children[registry.scope] = registry + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError('module must be a class, ' + f'but got {type(module_class)}') + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f'{name} is already registered ' + f'in {self.name}') + self._module_dict[name] = module_class + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + 'The old API of register_module(module, force=False) ' + 'is deprecated and will be removed, please use the new API ' + 'register_module(name=None, force=False, module=None) instead.') + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. + if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + 'name must be either of None, an instance of str or a sequence' + f' of str, but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module( + module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module( + module_class=cls, module_name=name, force=force) + return cls + + return _register diff --git a/annotator/uniformer/mmcv/utils/testing.py b/annotator/uniformer/mmcv/utils/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..a27f936da8ec14bac18562ede0a79d476d82f797 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/testing.py @@ -0,0 +1,140 @@ +# Copyright (c) Open-MMLab. +import sys +from collections.abc import Iterable +from runpy import run_path +from shlex import split +from typing import Any, Dict, List +from unittest.mock import patch + + +def check_python_script(cmd): + """Run the python cmd script with `__main__`. The difference between + `os.system` is that, this function exectues code in the current process, so + that it can be tracked by coverage tools. Currently it supports two forms: + + - ./tests/data/scripts/hello.py zz + - python tests/data/scripts/hello.py zz + """ + args = split(cmd) + if args[0] == 'python': + args = args[1:] + with patch.object(sys, 'argv', args): + run_path(args[0], run_name='__main__') + + +def _any(judge_result): + """Since built-in ``any`` works only when the element of iterable is not + iterable, implement the function.""" + if not isinstance(judge_result, Iterable): + return judge_result + + try: + for element in judge_result: + if _any(element): + return True + except TypeError: + # Maybe encounter the case: torch.tensor(True) | torch.tensor(False) + if judge_result: + return True + return False + + +def assert_dict_contains_subset(dict_obj: Dict[Any, Any], + expected_subset: Dict[Any, Any]) -> bool: + """Check if the dict_obj contains the expected_subset. + + Args: + dict_obj (Dict[Any, Any]): Dict object to be checked. + expected_subset (Dict[Any, Any]): Subset expected to be contained in + dict_obj. + + Returns: + bool: Whether the dict_obj contains the expected_subset. + """ + + for key, value in expected_subset.items(): + if key not in dict_obj.keys() or _any(dict_obj[key] != value): + return False + return True + + +def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool: + """Check if attribute of class object is correct. + + Args: + obj (object): Class object to be checked. + expected_attrs (Dict[str, Any]): Dict of the expected attrs. + + Returns: + bool: Whether the attribute of class object is correct. + """ + for attr, value in expected_attrs.items(): + if not hasattr(obj, attr) or _any(getattr(obj, attr) != value): + return False + return True + + +def assert_dict_has_keys(obj: Dict[str, Any], + expected_keys: List[str]) -> bool: + """Check if the obj has all the expected_keys. + + Args: + obj (Dict[str, Any]): Object to be checked. + expected_keys (List[str]): Keys expected to contained in the keys of + the obj. + + Returns: + bool: Whether the obj has the expected keys. + """ + return set(expected_keys).issubset(set(obj.keys())) + + +def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool: + """Check if target_keys is equal to result_keys. + + Args: + result_keys (List[str]): Result keys to be checked. + target_keys (List[str]): Target keys to be checked. + + Returns: + bool: Whether target_keys is equal to result_keys. + """ + return set(result_keys) == set(target_keys) + + +def assert_is_norm_layer(module) -> bool: + """Check if the module is a norm layer. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the module is a norm layer. + """ + from .parrots_wrapper import _BatchNorm, _InstanceNorm + from torch.nn import GroupNorm, LayerNorm + norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm) + return isinstance(module, norm_layer_candidates) + + +def assert_params_all_zeros(module) -> bool: + """Check if the parameters of the module is all zeros. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the parameters of the module is all zeros. + """ + weight_data = module.weight.data + is_weight_zero = weight_data.allclose( + weight_data.new_zeros(weight_data.size())) + + if hasattr(module, 'bias') and module.bias is not None: + bias_data = module.bias.data + is_bias_zero = bias_data.allclose( + bias_data.new_zeros(bias_data.size())) + else: + is_bias_zero = True + + return is_weight_zero and is_bias_zero diff --git a/annotator/uniformer/mmcv/utils/timer.py b/annotator/uniformer/mmcv/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..e3db7d497d8b374e18b5297e0a1d6eb186fd8cba --- /dev/null +++ b/annotator/uniformer/mmcv/utils/timer.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from time import time + + +class TimerError(Exception): + + def __init__(self, message): + self.message = message + super(TimerError, self).__init__(message) + + +class Timer: + """A flexible Timer class. + + :Example: + + >>> import time + >>> import annotator.uniformer.mmcv as mmcv + >>> with mmcv.Timer(): + >>> # simulate a code block that will run for 1s + >>> time.sleep(1) + 1.000 + >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'): + >>> # simulate a code block that will run for 1s + >>> time.sleep(1) + it takes 1.0 seconds + >>> timer = mmcv.Timer() + >>> time.sleep(0.5) + >>> print(timer.since_start()) + 0.500 + >>> time.sleep(0.5) + >>> print(timer.since_last_check()) + 0.500 + >>> print(timer.since_start()) + 1.000 + """ + + def __init__(self, start=True, print_tmpl=None): + self._is_running = False + self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}' + if start: + self.start() + + @property + def is_running(self): + """bool: indicate whether the timer is running""" + return self._is_running + + def __enter__(self): + self.start() + return self + + def __exit__(self, type, value, traceback): + print(self.print_tmpl.format(self.since_last_check())) + self._is_running = False + + def start(self): + """Start the timer.""" + if not self._is_running: + self._t_start = time() + self._is_running = True + self._t_last = time() + + def since_start(self): + """Total time since the timer is started. + + Returns (float): Time in seconds. + """ + if not self._is_running: + raise TimerError('timer is not running') + self._t_last = time() + return self._t_last - self._t_start + + def since_last_check(self): + """Time since the last checking. + + Either :func:`since_start` or :func:`since_last_check` is a checking + operation. + + Returns (float): Time in seconds. + """ + if not self._is_running: + raise TimerError('timer is not running') + dur = time() - self._t_last + self._t_last = time() + return dur + + +_g_timers = {} # global timers + + +def check_time(timer_id): + """Add check points in a single line. + + This method is suitable for running a task on a list of items. A timer will + be registered when the method is called for the first time. + + :Example: + + >>> import time + >>> import annotator.uniformer.mmcv as mmcv + >>> for i in range(1, 6): + >>> # simulate a code block + >>> time.sleep(i) + >>> mmcv.check_time('task1') + 2.000 + 3.000 + 4.000 + 5.000 + + Args: + timer_id (str): Timer identifier. + """ + if timer_id not in _g_timers: + _g_timers[timer_id] = Timer() + return 0 + else: + return _g_timers[timer_id].since_last_check() diff --git a/annotator/uniformer/mmcv/utils/trace.py b/annotator/uniformer/mmcv/utils/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca99dc3eda05ef980d9a4249b50deca8273b6cc --- /dev/null +++ b/annotator/uniformer/mmcv/utils/trace.py @@ -0,0 +1,23 @@ +import warnings + +import torch + +from annotator.uniformer.mmcv.utils import digit_version + + +def is_jit_tracing() -> bool: + if (torch.__version__ != 'parrots' + and digit_version(torch.__version__) >= digit_version('1.6.0')): + on_trace = torch.jit.is_tracing() + # In PyTorch 1.6, torch.jit.is_tracing has a bug. + # Refers to https://github.com/pytorch/pytorch/issues/42448 + if isinstance(on_trace, bool): + return on_trace + else: + return torch._C._is_tracing() + else: + warnings.warn( + 'torch.jit.is_tracing is only supported after v1.6.0. ' + 'Therefore is_tracing returns False automatically. Please ' + 'set on_trace manually if you are using trace.', UserWarning) + return False diff --git a/annotator/uniformer/mmcv/utils/version_utils.py b/annotator/uniformer/mmcv/utils/version_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..963c45a2e8a86a88413ab6c18c22481fb9831985 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/version_utils.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import subprocess +import warnings + +from packaging.version import parse + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + assert 'parrots' not in version_str + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) + + +def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen( + cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + +def get_git_hash(fallback='unknown', digits=None): + """Get the git hash of the current repo. + + Args: + fallback (str, optional): The fallback string when git hash is + unavailable. Defaults to 'unknown'. + digits (int, optional): kept digits of the hash. Defaults to None, + meaning all digits are kept. + + Returns: + str: Git commit hash. + """ + + if digits is not None and not isinstance(digits, int): + raise TypeError('digits must be None or an integer') + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + if digits is not None: + sha = sha[:digits] + except OSError: + sha = fallback + + return sha diff --git a/annotator/uniformer/mmcv/version.py b/annotator/uniformer/mmcv/version.py new file mode 100644 index 0000000000000000000000000000000000000000..1cce4e50bd692d4002e3cac3c545a3fb2efe95d0 --- /dev/null +++ b/annotator/uniformer/mmcv/version.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +__version__ = '1.3.17' + + +def parse_version_info(version_str: str, length: int = 4) -> tuple: + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into + (2, 0, 0, 0, 'rc', 1) (when length is set to 4). + """ + from packaging.version import parse + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + release.extend(list(version.pre)) + elif version.is_postrelease: + release.extend(list(version.post)) + else: + release.extend([0, 0]) + return tuple(release) + + +version_info = tuple(int(x) for x in __version__.split('.')[:3]) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] diff --git a/annotator/uniformer/mmcv/video/__init__.py b/annotator/uniformer/mmcv/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73199b01dec52820dc6ca0139903536344d5a1eb --- /dev/null +++ b/annotator/uniformer/mmcv/video/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .io import Cache, VideoReader, frames2video +from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread, + flowwrite, quantize_flow, sparse_flow_from_bytes) +from .processing import concat_video, convert_video, cut_video, resize_video + +__all__ = [ + 'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video', + 'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow', + 'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes' +] diff --git a/annotator/uniformer/mmcv/video/io.py b/annotator/uniformer/mmcv/video/io.py new file mode 100644 index 0000000000000000000000000000000000000000..9879154227f640c262853b92c219461c6f67ee8e --- /dev/null +++ b/annotator/uniformer/mmcv/video/io.py @@ -0,0 +1,318 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import OrderedDict + +import cv2 +from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT, + CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH, + CAP_PROP_POS_FRAMES, VideoWriter_fourcc) + +from annotator.uniformer.mmcv.utils import (check_file_exist, mkdir_or_exist, scandir, + track_progress) + + +class Cache: + + def __init__(self, capacity): + self._cache = OrderedDict() + self._capacity = int(capacity) + if capacity <= 0: + raise ValueError('capacity must be a positive integer') + + @property + def capacity(self): + return self._capacity + + @property + def size(self): + return len(self._cache) + + def put(self, key, val): + if key in self._cache: + return + if len(self._cache) >= self.capacity: + self._cache.popitem(last=False) + self._cache[key] = val + + def get(self, key, default=None): + val = self._cache[key] if key in self._cache else default + return val + + +class VideoReader: + """Video class with similar usage to a list object. + + This video warpper class provides convenient apis to access frames. + There exists an issue of OpenCV's VideoCapture class that jumping to a + certain frame may be inaccurate. It is fixed in this class by checking + the position after jumping each time. + Cache is used when decoding videos. So if the same frame is visited for + the second time, there is no need to decode again if it is stored in the + cache. + + :Example: + + >>> import annotator.uniformer.mmcv as mmcv + >>> v = mmcv.VideoReader('sample.mp4') + >>> len(v) # get the total frame number with `len()` + 120 + >>> for img in v: # v is iterable + >>> mmcv.imshow(img) + >>> v[5] # get the 6th frame + """ + + def __init__(self, filename, cache_capacity=10): + # Check whether the video path is a url + if not filename.startswith(('https://', 'http://')): + check_file_exist(filename, 'Video file not found: ' + filename) + self._vcap = cv2.VideoCapture(filename) + assert cache_capacity > 0 + self._cache = Cache(cache_capacity) + self._position = 0 + # get basic info + self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH)) + self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT)) + self._fps = self._vcap.get(CAP_PROP_FPS) + self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT)) + self._fourcc = self._vcap.get(CAP_PROP_FOURCC) + + @property + def vcap(self): + """:obj:`cv2.VideoCapture`: The raw VideoCapture object.""" + return self._vcap + + @property + def opened(self): + """bool: Indicate whether the video is opened.""" + return self._vcap.isOpened() + + @property + def width(self): + """int: Width of video frames.""" + return self._width + + @property + def height(self): + """int: Height of video frames.""" + return self._height + + @property + def resolution(self): + """tuple: Video resolution (width, height).""" + return (self._width, self._height) + + @property + def fps(self): + """float: FPS of the video.""" + return self._fps + + @property + def frame_cnt(self): + """int: Total frames of the video.""" + return self._frame_cnt + + @property + def fourcc(self): + """str: "Four character code" of the video.""" + return self._fourcc + + @property + def position(self): + """int: Current cursor position, indicating frame decoded.""" + return self._position + + def _get_real_position(self): + return int(round(self._vcap.get(CAP_PROP_POS_FRAMES))) + + def _set_real_position(self, frame_id): + self._vcap.set(CAP_PROP_POS_FRAMES, frame_id) + pos = self._get_real_position() + for _ in range(frame_id - pos): + self._vcap.read() + self._position = frame_id + + def read(self): + """Read the next frame. + + If the next frame have been decoded before and in the cache, then + return it directly, otherwise decode, cache and return it. + + Returns: + ndarray or None: Return the frame if successful, otherwise None. + """ + # pos = self._position + if self._cache: + img = self._cache.get(self._position) + if img is not None: + ret = True + else: + if self._position != self._get_real_position(): + self._set_real_position(self._position) + ret, img = self._vcap.read() + if ret: + self._cache.put(self._position, img) + else: + ret, img = self._vcap.read() + if ret: + self._position += 1 + return img + + def get_frame(self, frame_id): + """Get frame by index. + + Args: + frame_id (int): Index of the expected frame, 0-based. + + Returns: + ndarray or None: Return the frame if successful, otherwise None. + """ + if frame_id < 0 or frame_id >= self._frame_cnt: + raise IndexError( + f'"frame_id" must be between 0 and {self._frame_cnt - 1}') + if frame_id == self._position: + return self.read() + if self._cache: + img = self._cache.get(frame_id) + if img is not None: + self._position = frame_id + 1 + return img + self._set_real_position(frame_id) + ret, img = self._vcap.read() + if ret: + if self._cache: + self._cache.put(self._position, img) + self._position += 1 + return img + + def current_frame(self): + """Get the current frame (frame that is just visited). + + Returns: + ndarray or None: If the video is fresh, return None, otherwise + return the frame. + """ + if self._position == 0: + return None + return self._cache.get(self._position - 1) + + def cvt2frames(self, + frame_dir, + file_start=0, + filename_tmpl='{:06d}.jpg', + start=0, + max_num=0, + show_progress=True): + """Convert a video to frame images. + + Args: + frame_dir (str): Output directory to store all the frame images. + file_start (int): Filenames will start from the specified number. + filename_tmpl (str): Filename template with the index as the + placeholder. + start (int): The starting frame index. + max_num (int): Maximum number of frames to be written. + show_progress (bool): Whether to show a progress bar. + """ + mkdir_or_exist(frame_dir) + if max_num == 0: + task_num = self.frame_cnt - start + else: + task_num = min(self.frame_cnt - start, max_num) + if task_num <= 0: + raise ValueError('start must be less than total frame number') + if start > 0: + self._set_real_position(start) + + def write_frame(file_idx): + img = self.read() + if img is None: + return + filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) + cv2.imwrite(filename, img) + + if show_progress: + track_progress(write_frame, range(file_start, + file_start + task_num)) + else: + for i in range(task_num): + write_frame(file_start + i) + + def __len__(self): + return self.frame_cnt + + def __getitem__(self, index): + if isinstance(index, slice): + return [ + self.get_frame(i) + for i in range(*index.indices(self.frame_cnt)) + ] + # support negative indexing + if index < 0: + index += self.frame_cnt + if index < 0: + raise IndexError('index out of range') + return self.get_frame(index) + + def __iter__(self): + self._set_real_position(0) + return self + + def __next__(self): + img = self.read() + if img is not None: + return img + else: + raise StopIteration + + next = __next__ + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._vcap.release() + + +def frames2video(frame_dir, + video_file, + fps=30, + fourcc='XVID', + filename_tmpl='{:06d}.jpg', + start=0, + end=0, + show_progress=True): + """Read the frame images from a directory and join them as a video. + + Args: + frame_dir (str): The directory containing video frames. + video_file (str): Output filename. + fps (float): FPS of the output video. + fourcc (str): Fourcc of the output video, this should be compatible + with the output file type. + filename_tmpl (str): Filename template with the index as the variable. + start (int): Starting frame index. + end (int): Ending frame index. + show_progress (bool): Whether to show a progress bar. + """ + if end == 0: + ext = filename_tmpl.split('.')[-1] + end = len([name for name in scandir(frame_dir, ext)]) + first_file = osp.join(frame_dir, filename_tmpl.format(start)) + check_file_exist(first_file, 'The start frame not found: ' + first_file) + img = cv2.imread(first_file) + height, width = img.shape[:2] + resolution = (width, height) + vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps, + resolution) + + def write_frame(file_idx): + filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) + img = cv2.imread(filename) + vwriter.write(img) + + if show_progress: + track_progress(write_frame, range(start, end)) + else: + for i in range(start, end): + write_frame(i) + vwriter.release() diff --git a/annotator/uniformer/mmcv/video/optflow.py b/annotator/uniformer/mmcv/video/optflow.py new file mode 100644 index 0000000000000000000000000000000000000000..84160f8d6ef9fceb5a2f89e7481593109fc1905d --- /dev/null +++ b/annotator/uniformer/mmcv/video/optflow.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import cv2 +import numpy as np + +from annotator.uniformer.mmcv.arraymisc import dequantize, quantize +from annotator.uniformer.mmcv.image import imread, imwrite +from annotator.uniformer.mmcv.utils import is_str + + +def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_or_path (ndarray or str): A flow map or filepath. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if isinstance(flow_or_path, np.ndarray): + if (flow_or_path.ndim != 3) or (flow_or_path.shape[-1] != 2): + raise ValueError(f'Invalid flow with shape {flow_or_path.shape}') + return flow_or_path + elif not is_str(flow_or_path): + raise TypeError(f'"flow_or_path" must be a filename or numpy array, ' + f'not {type(flow_or_path)}') + + if not quantize: + with open(flow_or_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_or_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_or_path}, ' + 'header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + else: + assert concat_axis in [0, 1] + cat_flow = imread(flow_or_path, flag='unchanged') + if cat_flow.ndim != 2: + raise IOError( + f'{flow_or_path} is not a valid quantized flow file, ' + f'its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + imwrite(dxdy, filename) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [ + quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] + ] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. + """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'): + """Use flow to warp img. + + Args: + img (ndarray, float or uint8): Image to be warped. + flow (ndarray, float): Optical Flow. + filling_value (int): The missing pixels will be set with filling_value. + interpolate_mode (str): bilinear -> Bilinear Interpolation; + nearest -> Nearest Neighbor. + + Returns: + ndarray: Warped image with the same shape of img + """ + warnings.warn('This function is just for prototyping and cannot ' + 'guarantee the computational efficiency.') + assert flow.ndim == 3, 'Flow must be in 3D arrays.' + height = flow.shape[0] + width = flow.shape[1] + channels = img.shape[2] + + output = np.ones( + (height, width, channels), dtype=img.dtype) * filling_value + + grid = np.indices((height, width)).swapaxes(0, 1).swapaxes(1, 2) + dx = grid[:, :, 0] + flow[:, :, 1] + dy = grid[:, :, 1] + flow[:, :, 0] + sx = np.floor(dx).astype(int) + sy = np.floor(dy).astype(int) + valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1) + + if interpolate_mode == 'nearest': + output[valid, :] = img[dx[valid].round().astype(int), + dy[valid].round().astype(int), :] + elif interpolate_mode == 'bilinear': + # dirty walkround for integer positions + eps_ = 1e-6 + dx, dy = dx + eps_, dy + eps_ + left_top_ = img[np.floor(dx[valid]).astype(int), + np.floor(dy[valid]).astype(int), :] * ( + np.ceil(dx[valid]) - dx[valid])[:, None] * ( + np.ceil(dy[valid]) - dy[valid])[:, None] + left_down_ = img[np.ceil(dx[valid]).astype(int), + np.floor(dy[valid]).astype(int), :] * ( + dx[valid] - np.floor(dx[valid]))[:, None] * ( + np.ceil(dy[valid]) - dy[valid])[:, None] + right_top_ = img[np.floor(dx[valid]).astype(int), + np.ceil(dy[valid]).astype(int), :] * ( + np.ceil(dx[valid]) - dx[valid])[:, None] * ( + dy[valid] - np.floor(dy[valid]))[:, None] + right_down_ = img[np.ceil(dx[valid]).astype(int), + np.ceil(dy[valid]).astype(int), :] * ( + dx[valid] - np.floor(dx[valid]))[:, None] * ( + dy[valid] - np.floor(dy[valid]))[:, None] + output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_ + else: + raise NotImplementedError( + 'We only support interpolation modes of nearest and bilinear, ' + f'but got {interpolate_mode}.') + return output.astype(img.dtype) + + +def flow_from_bytes(content): + """Read dense optical flow from bytes. + + .. note:: + This load optical flow function works for FlyingChairs, FlyingThings3D, + Sintel, FlyingChairsOcc datasets, but cannot load the data from + ChairsSDHom. + + Args: + content (bytes): Optical flow bytes got from files or other streams. + + Returns: + ndarray: Loaded optical flow with the shape (H, W, 2). + """ + + # header in first 4 bytes + header = content[:4] + if header.decode('utf-8') != 'PIEH': + raise Exception('Flow file header does not contain PIEH') + # width in second 4 bytes + width = np.frombuffer(content[4:], np.int32, 1).squeeze() + # height in third 4 bytes + height = np.frombuffer(content[8:], np.int32, 1).squeeze() + # after first 12 bytes, all bytes are flow + flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape( + (height, width, 2)) + + return flow + + +def sparse_flow_from_bytes(content): + """Read the optical flow in KITTI datasets from bytes. + + This function is modified from RAFT load the `KITTI datasets + `_. + + Args: + content (bytes): Optical flow bytes got from files or other streams. + + Returns: + Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2) + and flow valid mask with the shape (H, W). + """ # nopa + + content = np.frombuffer(content, np.uint8) + flow = cv2.imdecode(content, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + # flow shape (H, W, 2) valid shape (H, W) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid diff --git a/annotator/uniformer/mmcv/video/processing.py b/annotator/uniformer/mmcv/video/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..3d90b96e0823d5f116755e7f498d25d17017224a --- /dev/null +++ b/annotator/uniformer/mmcv/video/processing.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import subprocess +import tempfile + +from annotator.uniformer.mmcv.utils import requires_executable + + +@requires_executable('ffmpeg') +def convert_video(in_file, + out_file, + print_cmd=False, + pre_options='', + **kwargs): + """Convert a video with ffmpeg. + + This provides a general api to ffmpeg, the executed command is:: + + `ffmpeg -y -i ` + + Options(kwargs) are mapped to ffmpeg commands with the following rules: + + - key=val: "-key val" + - key=True: "-key" + - key=False: "" + + Args: + in_file (str): Input video filename. + out_file (str): Output video filename. + pre_options (str): Options appears before "-i ". + print_cmd (bool): Whether to print the final ffmpeg command. + """ + options = [] + for k, v in kwargs.items(): + if isinstance(v, bool): + if v: + options.append(f'-{k}') + elif k == 'log_level': + assert v in [ + 'quiet', 'panic', 'fatal', 'error', 'warning', 'info', + 'verbose', 'debug', 'trace' + ] + options.append(f'-loglevel {v}') + else: + options.append(f'-{k} {v}') + cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \ + f'{out_file}' + if print_cmd: + print(cmd) + subprocess.call(cmd, shell=True) + + +@requires_executable('ffmpeg') +def resize_video(in_file, + out_file, + size=None, + ratio=None, + keep_ar=False, + log_level='info', + print_cmd=False): + """Resize a video. + + Args: + in_file (str): Input video filename. + out_file (str): Output video filename. + size (tuple): Expected size (w, h), eg, (320, 240) or (320, -1). + ratio (tuple or float): Expected resize ratio, (2, 0.5) means + (w*2, h*0.5). + keep_ar (bool): Whether to keep original aspect ratio. + log_level (str): Logging level of ffmpeg. + print_cmd (bool): Whether to print the final ffmpeg command. + """ + if size is None and ratio is None: + raise ValueError('expected size or ratio must be specified') + if size is not None and ratio is not None: + raise ValueError('size and ratio cannot be specified at the same time') + options = {'log_level': log_level} + if size: + if not keep_ar: + options['vf'] = f'scale={size[0]}:{size[1]}' + else: + options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \ + 'force_original_aspect_ratio=decrease' + else: + if not isinstance(ratio, tuple): + ratio = (ratio, ratio) + options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"' + convert_video(in_file, out_file, print_cmd, **options) + + +@requires_executable('ffmpeg') +def cut_video(in_file, + out_file, + start=None, + end=None, + vcodec=None, + acodec=None, + log_level='info', + print_cmd=False): + """Cut a clip from a video. + + Args: + in_file (str): Input video filename. + out_file (str): Output video filename. + start (None or float): Start time (in seconds). + end (None or float): End time (in seconds). + vcodec (None or str): Output video codec, None for unchanged. + acodec (None or str): Output audio codec, None for unchanged. + log_level (str): Logging level of ffmpeg. + print_cmd (bool): Whether to print the final ffmpeg command. + """ + options = {'log_level': log_level} + if vcodec is None: + options['vcodec'] = 'copy' + if acodec is None: + options['acodec'] = 'copy' + if start: + options['ss'] = start + else: + start = 0 + if end: + options['t'] = end - start + convert_video(in_file, out_file, print_cmd, **options) + + +@requires_executable('ffmpeg') +def concat_video(video_list, + out_file, + vcodec=None, + acodec=None, + log_level='info', + print_cmd=False): + """Concatenate multiple videos into a single one. + + Args: + video_list (list): A list of video filenames + out_file (str): Output video filename + vcodec (None or str): Output video codec, None for unchanged + acodec (None or str): Output audio codec, None for unchanged + log_level (str): Logging level of ffmpeg. + print_cmd (bool): Whether to print the final ffmpeg command. + """ + tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True) + with open(tmp_filename, 'w') as f: + for filename in video_list: + f.write(f'file {osp.abspath(filename)}\n') + options = {'log_level': log_level} + if vcodec is None: + options['vcodec'] = 'copy' + if acodec is None: + options['acodec'] = 'copy' + convert_video( + tmp_filename, + out_file, + print_cmd, + pre_options='-f concat -safe 0', + **options) + os.close(tmp_filehandler) + os.remove(tmp_filename) diff --git a/annotator/uniformer/mmcv/visualization/__init__.py b/annotator/uniformer/mmcv/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..835df136bdcf69348281d22914d41aa84cdf92b1 --- /dev/null +++ b/annotator/uniformer/mmcv/visualization/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .color import Color, color_val +from .image import imshow, imshow_bboxes, imshow_det_bboxes +from .optflow import flow2rgb, flowshow, make_color_wheel + +__all__ = [ + 'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes', + 'flowshow', 'flow2rgb', 'make_color_wheel' +] diff --git a/annotator/uniformer/mmcv/visualization/color.py b/annotator/uniformer/mmcv/visualization/color.py new file mode 100644 index 0000000000000000000000000000000000000000..9041e0e6b7581c3356795d6a3c5e84667c88f025 --- /dev/null +++ b/annotator/uniformer/mmcv/visualization/color.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from enum import Enum + +import numpy as np + +from annotator.uniformer.mmcv.utils import is_str + + +class Color(Enum): + """An enum that defines common colors. + + Contains red, green, blue, cyan, yellow, magenta, white and black. + """ + red = (0, 0, 255) + green = (0, 255, 0) + blue = (255, 0, 0) + cyan = (255, 255, 0) + yellow = (0, 255, 255) + magenta = (255, 0, 255) + white = (255, 255, 255) + black = (0, 0, 0) + + +def color_val(color): + """Convert various input to color tuples. + + Args: + color (:obj:`Color`/str/tuple/int/ndarray): Color inputs + + Returns: + tuple[int]: A tuple of 3 integers indicating BGR channels. + """ + if is_str(color): + return Color[color].value + elif isinstance(color, Color): + return color.value + elif isinstance(color, tuple): + assert len(color) == 3 + for channel in color: + assert 0 <= channel <= 255 + return color + elif isinstance(color, int): + assert 0 <= color <= 255 + return color, color, color + elif isinstance(color, np.ndarray): + assert color.ndim == 1 and color.size == 3 + assert np.all((color >= 0) & (color <= 255)) + color = color.astype(np.uint8) + return tuple(color) + else: + raise TypeError(f'Invalid type for color: {type(color)}') diff --git a/annotator/uniformer/mmcv/visualization/image.py b/annotator/uniformer/mmcv/visualization/image.py new file mode 100644 index 0000000000000000000000000000000000000000..61a56c75b67f593c298408462c63c0468be8e276 --- /dev/null +++ b/annotator/uniformer/mmcv/visualization/image.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + +from annotator.uniformer.mmcv.image import imread, imwrite +from .color import color_val + + +def imshow(img, win_name='', wait_time=0): + """Show an image. + + Args: + img (str or ndarray): The image to be displayed. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + """ + cv2.imshow(win_name, imread(img)) + if wait_time == 0: # prevent from hanging if windows was closed + while True: + ret = cv2.waitKey(1) + + closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1 + # if user closed window or if some key pressed + if closed or ret != -1: + break + else: + ret = cv2.waitKey(wait_time) + + +def imshow_bboxes(img, + bboxes, + colors='green', + top_k=-1, + thickness=1, + show=True, + win_name='', + wait_time=0, + out_file=None): + """Draw bboxes on an image. + + Args: + img (str or ndarray): The image to be displayed. + bboxes (list or ndarray): A list of ndarray of shape (k, 4). + colors (list[str or tuple or Color]): A list of colors. + top_k (int): Plot the first k bboxes only if set positive. + thickness (int): Thickness of lines. + show (bool): Whether to show the image. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + out_file (str, optional): The filename to write the image. + + Returns: + ndarray: The image with bboxes drawn on it. + """ + img = imread(img) + img = np.ascontiguousarray(img) + + if isinstance(bboxes, np.ndarray): + bboxes = [bboxes] + if not isinstance(colors, list): + colors = [colors for _ in range(len(bboxes))] + colors = [color_val(c) for c in colors] + assert len(bboxes) == len(colors) + + for i, _bboxes in enumerate(bboxes): + _bboxes = _bboxes.astype(np.int32) + if top_k <= 0: + _top_k = _bboxes.shape[0] + else: + _top_k = min(top_k, _bboxes.shape[0]) + for j in range(_top_k): + left_top = (_bboxes[j, 0], _bboxes[j, 1]) + right_bottom = (_bboxes[j, 2], _bboxes[j, 3]) + cv2.rectangle( + img, left_top, right_bottom, colors[i], thickness=thickness) + + if show: + imshow(img, win_name, wait_time) + if out_file is not None: + imwrite(img, out_file) + return img + + +def imshow_det_bboxes(img, + bboxes, + labels, + class_names=None, + score_thr=0, + bbox_color='green', + text_color='green', + thickness=1, + font_scale=0.5, + show=True, + win_name='', + wait_time=0, + out_file=None): + """Draw bboxes and class labels (with scores) on an image. + + Args: + img (str or ndarray): The image to be displayed. + bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or + (n, 5). + labels (ndarray): Labels of bboxes. + class_names (list[str]): Names of each classes. + score_thr (float): Minimum score of bboxes to be shown. + bbox_color (str or tuple or :obj:`Color`): Color of bbox lines. + text_color (str or tuple or :obj:`Color`): Color of texts. + thickness (int): Thickness of lines. + font_scale (float): Font scales of texts. + show (bool): Whether to show the image. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + out_file (str or None): The filename to write the image. + + Returns: + ndarray: The image with bboxes drawn on it. + """ + assert bboxes.ndim == 2 + assert labels.ndim == 1 + assert bboxes.shape[0] == labels.shape[0] + assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5 + img = imread(img) + img = np.ascontiguousarray(img) + + if score_thr > 0: + assert bboxes.shape[1] == 5 + scores = bboxes[:, -1] + inds = scores > score_thr + bboxes = bboxes[inds, :] + labels = labels[inds] + + bbox_color = color_val(bbox_color) + text_color = color_val(text_color) + + for bbox, label in zip(bboxes, labels): + bbox_int = bbox.astype(np.int32) + left_top = (bbox_int[0], bbox_int[1]) + right_bottom = (bbox_int[2], bbox_int[3]) + cv2.rectangle( + img, left_top, right_bottom, bbox_color, thickness=thickness) + label_text = class_names[ + label] if class_names is not None else f'cls {label}' + if len(bbox) > 4: + label_text += f'|{bbox[-1]:.02f}' + cv2.putText(img, label_text, (bbox_int[0], bbox_int[1] - 2), + cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) + + if show: + imshow(img, win_name, wait_time) + if out_file is not None: + imwrite(img, out_file) + return img diff --git a/annotator/uniformer/mmcv/visualization/optflow.py b/annotator/uniformer/mmcv/visualization/optflow.py new file mode 100644 index 0000000000000000000000000000000000000000..c3870c700f7c946177ee5d536ce3f6c814a77ce7 --- /dev/null +++ b/annotator/uniformer/mmcv/visualization/optflow.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from __future__ import division + +import numpy as np + +from annotator.uniformer.mmcv.image import rgb2bgr +from annotator.uniformer.mmcv.video import flowread +from .image import imshow + + +def flowshow(flow, win_name='', wait_time=0): + """Show optical flow. + + Args: + flow (ndarray or str): The optical flow to be displayed. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + """ + flow = flowread(flow) + flow_img = flow2rgb(flow) + imshow(rgb2bgr(flow_img), win_name, wait_time) + + +def flow2rgb(flow, color_wheel=None, unknown_thr=1e6): + """Convert flow map to RGB image. + + Args: + flow (ndarray): Array of optical flow. + color_wheel (ndarray or None): Color wheel used to map flow field to + RGB colorspace. Default color wheel will be used if not specified. + unknown_thr (str): Values above this threshold will be marked as + unknown and thus ignored. + + Returns: + ndarray: RGB image that can be visualized. + """ + assert flow.ndim == 3 and flow.shape[-1] == 2 + if color_wheel is None: + color_wheel = make_color_wheel() + assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3 + num_bins = color_wheel.shape[0] + + dx = flow[:, :, 0].copy() + dy = flow[:, :, 1].copy() + + ignore_inds = ( + np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) | + (np.abs(dy) > unknown_thr)) + dx[ignore_inds] = 0 + dy[ignore_inds] = 0 + + rad = np.sqrt(dx**2 + dy**2) + if np.any(rad > np.finfo(float).eps): + max_rad = np.max(rad) + dx /= max_rad + dy /= max_rad + + rad = np.sqrt(dx**2 + dy**2) + angle = np.arctan2(-dy, -dx) / np.pi + + bin_real = (angle + 1) / 2 * (num_bins - 1) + bin_left = np.floor(bin_real).astype(int) + bin_right = (bin_left + 1) % num_bins + w = (bin_real - bin_left.astype(np.float32))[..., None] + flow_img = (1 - + w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :] + small_ind = rad <= 1 + flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind]) + flow_img[np.logical_not(small_ind)] *= 0.75 + + flow_img[ignore_inds, :] = 0 + + return flow_img + + +def make_color_wheel(bins=None): + """Build a color wheel. + + Args: + bins(list or tuple, optional): Specify the number of bins for each + color range, corresponding to six ranges: red -> yellow, + yellow -> green, green -> cyan, cyan -> blue, blue -> magenta, + magenta -> red. [15, 6, 4, 11, 13, 6] is used for default + (see Middlebury). + + Returns: + ndarray: Color wheel of shape (total_bins, 3). + """ + if bins is None: + bins = [15, 6, 4, 11, 13, 6] + assert len(bins) == 6 + + RY, YG, GC, CB, BM, MR = tuple(bins) + + ry = [1, np.arange(RY) / RY, 0] + yg = [1 - np.arange(YG) / YG, 1, 0] + gc = [0, 1, np.arange(GC) / GC] + cb = [0, 1 - np.arange(CB) / CB, 1] + bm = [np.arange(BM) / BM, 0, 1] + mr = [1, 0, 1 - np.arange(MR) / MR] + + num_bins = RY + YG + GC + CB + BM + MR + + color_wheel = np.zeros((3, num_bins), dtype=np.float32) + + col = 0 + for i, color in enumerate([ry, yg, gc, cb, bm, mr]): + for j in range(3): + color_wheel[j, col:col + bins[i]] = color[j] + col += bins[i] + + return color_wheel.T diff --git a/annotator/uniformer/mmcv_custom/__init__.py b/annotator/uniformer/mmcv_custom/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b958738b9fd93bfcec239c550df1d9a44b8c536 --- /dev/null +++ b/annotator/uniformer/mmcv_custom/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- + +from .checkpoint import load_checkpoint + +__all__ = ['load_checkpoint'] \ No newline at end of file diff --git a/annotator/uniformer/mmcv_custom/checkpoint.py b/annotator/uniformer/mmcv_custom/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..19b87fef0a52d31babcdb3edb8f3089b6420173f --- /dev/null +++ b/annotator/uniformer/mmcv_custom/checkpoint.py @@ -0,0 +1,500 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import torch +import torchvision +from torch.optim import Optimizer +from torch.utils import model_zoo +from torch.nn import functional as F + +import annotator.uniformer.mmcv as mmcv +from annotator.uniformer.mmcv.fileio import FileClient +from annotator.uniformer.mmcv.fileio import load as load_file +from annotator.uniformer.mmcv.parallel import is_module_wrapper +from annotator.uniformer.mmcv.utils import mkdir_or_exist +from annotator.uniformer.mmcv.runner import get_dist_info + +ENV_MMCV_HOME = 'MMCV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load( + downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + allowed_backends = ['ceph'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('torchvision://'): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('open-mmlab://'): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith('mmcls://'): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(('http://', 'https://')): + checkpoint = load_url_dist(filename) + elif filename.startswith('pavi://'): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith('s3://'): + checkpoint = load_fileclient_dist( + filename, backend='ceph', map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H*W: + logger.warning("Error in loading absolute_pos_embed, pass") + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) + + # interpolate position bias table if needed + relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f"Error in loading {table_key}, pass") + else: + if L1 != L2: + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() \ No newline at end of file diff --git a/annotator/uniformer/mmseg/apis/__init__.py b/annotator/uniformer/mmseg/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..170724be38de42daf2bc1a1910e181d68818f165 --- /dev/null +++ b/annotator/uniformer/mmseg/apis/__init__.py @@ -0,0 +1,9 @@ +from .inference import inference_segmentor, init_segmentor, show_result_pyplot +from .test import multi_gpu_test, single_gpu_test +from .train import get_root_logger, set_random_seed, train_segmentor + +__all__ = [ + 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', + 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', + 'show_result_pyplot' +] diff --git a/annotator/uniformer/mmseg/apis/inference.py b/annotator/uniformer/mmseg/apis/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..90bc1c0c68525734bd6793f07c15fe97d3c8342c --- /dev/null +++ b/annotator/uniformer/mmseg/apis/inference.py @@ -0,0 +1,136 @@ +import matplotlib.pyplot as plt +import annotator.uniformer.mmcv as mmcv +import torch +from annotator.uniformer.mmcv.parallel import collate, scatter +from annotator.uniformer.mmcv.runner import load_checkpoint + +from annotator.uniformer.mmseg.datasets.pipelines import Compose +from annotator.uniformer.mmseg.models import build_segmentor + + +def init_segmentor(config, checkpoint=None, device='cuda:0'): + """Initialize a segmentor from config file. + + Args: + config (str or :obj:`mmcv.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. + Returns: + nn.Module: The constructed segmentor. + """ + if isinstance(config, str): + config = mmcv.Config.fromfile(config) + elif not isinstance(config, mmcv.Config): + raise TypeError('config must be a filename or Config object, ' + 'but got {}'.format(type(config))) + config.model.pretrained = None + config.model.train_cfg = None + model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + model.CLASSES = checkpoint['meta']['CLASSES'] + model.PALETTE = checkpoint['meta']['PALETTE'] + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +class LoadImage: + """A simple pipeline to load image.""" + + def __call__(self, results): + """Call function to load images into results. + + Args: + results (dict): A result dict contains the file name + of the image to be read. + + Returns: + dict: ``results`` will be returned containing loaded image. + """ + + if isinstance(results['img'], str): + results['filename'] = results['img'] + results['ori_filename'] = results['img'] + else: + results['filename'] = None + results['ori_filename'] = None + img = mmcv.imread(results['img']) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + return results + + +def inference_segmentor(model, img): + """Inference image(s) with the segmentor. + + Args: + model (nn.Module): The loaded segmentor. + imgs (str/ndarray or list[str/ndarray]): Either image files or loaded + images. + + Returns: + (list[Tensor]): The segmentation result. + """ + cfg = model.cfg + device = next(model.parameters()).device # model device + # build the data pipeline + test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = dict(img=img) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + if next(model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [device])[0] + else: + data['img_metas'] = [i.data[0] for i in data['img_metas']] + + # forward the model + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + return result + + +def show_result_pyplot(model, + img, + result, + palette=None, + fig_size=(15, 10), + opacity=0.5, + title='', + block=True): + """Visualize the segmentation results on the image. + + Args: + model (nn.Module): The loaded segmentor. + img (str or np.ndarray): Image filename or loaded image. + result (list): The segmentation result. + palette (list[list[int]]] | None): The palette of segmentation + map. If None is given, random palette will be generated. + Default: None + fig_size (tuple): Figure size of the pyplot figure. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + title (str): The title of pyplot figure. + Default is ''. + block (bool): Whether to block the pyplot figure. + Default is True. + """ + if hasattr(model, 'module'): + model = model.module + img = model.show_result( + img, result, palette=palette, show=False, opacity=opacity) + # plt.figure(figsize=fig_size) + # plt.imshow(mmcv.bgr2rgb(img)) + # plt.title(title) + # plt.tight_layout() + # plt.show(block=block) + return mmcv.bgr2rgb(img) diff --git a/annotator/uniformer/mmseg/apis/test.py b/annotator/uniformer/mmseg/apis/test.py new file mode 100644 index 0000000000000000000000000000000000000000..e574eb7da04f09a59cf99ff953c36468ae87a326 --- /dev/null +++ b/annotator/uniformer/mmseg/apis/test.py @@ -0,0 +1,238 @@ +import os.path as osp +import pickle +import shutil +import tempfile + +import annotator.uniformer.mmcv as mmcv +import numpy as np +import torch +import torch.distributed as dist +from annotator.uniformer.mmcv.image import tensor2imgs +from annotator.uniformer.mmcv.runner import get_dist_info + + +def np2tmp(array, temp_file_name=None): + """Save ndarray to local numpy file. + + Args: + array (ndarray): Ndarray to save. + temp_file_name (str): Numpy file name. If 'temp_file_name=None', this + function will generate a file name with tempfile.NamedTemporaryFile + to save ndarray. Default: None. + + Returns: + str: The numpy file name. + """ + + if temp_file_name is None: + temp_file_name = tempfile.NamedTemporaryFile( + suffix='.npy', delete=False).name + np.save(temp_file_name, array) + return temp_file_name + + +def single_gpu_test(model, + data_loader, + show=False, + out_dir=None, + efficient_test=False, + opacity=0.5): + """Test with single GPU. + + Args: + model (nn.Module): Model to be tested. + data_loader (utils.data.Dataloader): Pytorch data loader. + show (bool): Whether show results during inference. Default: False. + out_dir (str, optional): If specified, the results will be dumped into + the directory to save output results. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + Returns: + list: The prediction results. + """ + + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + + if show or out_dir: + img_tensor = data['img'][0] + img_metas = data['img_metas'][0].data[0] + imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) + assert len(imgs) == len(img_metas) + + for img, img_meta in zip(imgs, img_metas): + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + + ori_h, ori_w = img_meta['ori_shape'][:-1] + img_show = mmcv.imresize(img_show, (ori_w, ori_h)) + + if out_dir: + out_file = osp.join(out_dir, img_meta['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result, + palette=dataset.PALETTE, + show=show, + out_file=out_file, + opacity=opacity) + + if isinstance(result, list): + if efficient_test: + result = [np2tmp(_) for _ in result] + results.extend(result) + else: + if efficient_test: + result = np2tmp(result) + results.append(result) + + batch_size = len(result) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, + data_loader, + tmpdir=None, + gpu_collect=False, + efficient_test=False): + """Test model with multiple gpus. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' + it encodes results to gpu tensors and use gpu communication for results + collection. On cpu mode it saves the results on different gpus to 'tmpdir' + and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (utils.data.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + + Returns: + list: The prediction results. + """ + + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + + if isinstance(result, list): + if efficient_test: + result = [np2tmp(_) for _ in result] + results.extend(result) + else: + if efficient_test: + result = np2tmp(result) + results.append(result) + + if rank == 0: + batch_size = data['img'][0].size(0) + for _ in range(batch_size * world_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results with CPU.""" + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + tmpdir = tempfile.mkdtemp() + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + mmcv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank))) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i)) + part_list.append(mmcv.load(part_file)) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results + + +def collect_results_gpu(result_part, size): + """Collect results with GPU.""" + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_list.append( + pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results diff --git a/annotator/uniformer/mmseg/apis/train.py b/annotator/uniformer/mmseg/apis/train.py new file mode 100644 index 0000000000000000000000000000000000000000..63f319a919ff023931a6a663e668f27dd1a07a2e --- /dev/null +++ b/annotator/uniformer/mmseg/apis/train.py @@ -0,0 +1,116 @@ +import random +import warnings + +import numpy as np +import torch +from annotator.uniformer.mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from annotator.uniformer.mmcv.runner import build_optimizer, build_runner + +from annotator.uniformer.mmseg.core import DistEvalHook, EvalHook +from annotator.uniformer.mmseg.datasets import build_dataloader, build_dataset +from annotator.uniformer.mmseg.utils import get_root_logger + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def train_segmentor(model, + dataset, + cfg, + distributed=False, + validate=False, + timestamp=None, + meta=None): + """Launch segmentor training.""" + logger = get_root_logger(cfg.log_level) + + # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] + data_loaders = [ + build_dataloader( + ds, + cfg.data.samples_per_gpu, + cfg.data.workers_per_gpu, + # cfg.gpus will be ignored if distributed + len(cfg.gpu_ids), + dist=distributed, + seed=cfg.seed, + drop_last=True) for ds in dataset + ] + + # put model on gpus + if distributed: + find_unused_parameters = cfg.get('find_unused_parameters', False) + # Sets the `find_unused_parameters` parameter in + # torch.nn.parallel.DistributedDataParallel + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + model = MMDataParallel( + model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) + + # build runner + optimizer = build_optimizer(model, cfg.optimizer) + + if cfg.get('runner') is None: + cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} + warnings.warn( + 'config is now expected to have a `runner` section, ' + 'please set `runner` in your config.', UserWarning) + + runner = build_runner( + cfg.runner, + default_args=dict( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + + # register hooks + runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, + cfg.checkpoint_config, cfg.log_config, + cfg.get('momentum_config', None)) + + # an ugly walkaround to make the .log and .log.json filenames the same + runner.timestamp = timestamp + + # register eval hooks + if validate: + val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) + val_dataloader = build_dataloader( + val_dataset, + samples_per_gpu=1, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False) + eval_cfg = cfg.get('evaluation', {}) + eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' + eval_hook = DistEvalHook if distributed else EvalHook + runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW') + + if cfg.resume_from: + runner.resume(cfg.resume_from) + elif cfg.load_from: + runner.load_checkpoint(cfg.load_from) + runner.run(data_loaders, cfg.workflow) diff --git a/annotator/uniformer/mmseg/core/__init__.py b/annotator/uniformer/mmseg/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..965605587211b7bf0bd6bc3acdbb33dd49cab023 --- /dev/null +++ b/annotator/uniformer/mmseg/core/__init__.py @@ -0,0 +1,3 @@ +from .evaluation import * # noqa: F401, F403 +from .seg import * # noqa: F401, F403 +from .utils import * # noqa: F401, F403 diff --git a/annotator/uniformer/mmseg/core/evaluation/__init__.py b/annotator/uniformer/mmseg/core/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7cc4b23413a0639e9de00eeb0bf600632d2c6cd --- /dev/null +++ b/annotator/uniformer/mmseg/core/evaluation/__init__.py @@ -0,0 +1,8 @@ +from .class_names import get_classes, get_palette +from .eval_hooks import DistEvalHook, EvalHook +from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou + +__all__ = [ + 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', + 'eval_metrics', 'get_classes', 'get_palette' +] diff --git a/annotator/uniformer/mmseg/core/evaluation/class_names.py b/annotator/uniformer/mmseg/core/evaluation/class_names.py new file mode 100644 index 0000000000000000000000000000000000000000..ffae816cf980ce4b03e491cc0c4298cb823797e6 --- /dev/null +++ b/annotator/uniformer/mmseg/core/evaluation/class_names.py @@ -0,0 +1,152 @@ +import annotator.uniformer.mmcv as mmcv + + +def cityscapes_classes(): + """Cityscapes class names for external use.""" + return [ + 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def ade_classes(): + """ADE20K class names for external use.""" + return [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag' + ] + + +def voc_classes(): + """Pascal VOC class names for external use.""" + return [ + 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', + 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', + 'tvmonitor' + ] + + +def cityscapes_palette(): + """Cityscapes palette for external use.""" + return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], + [0, 0, 230], [119, 11, 32]] + + +def ade_palette(): + """ADE20K palette for external use.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + +def voc_palette(): + """Pascal VOC palette for external use.""" + return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + +dataset_aliases = { + 'cityscapes': ['cityscapes'], + 'ade': ['ade', 'ade20k'], + 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'] +} + + +def get_classes(dataset): + """Get class names of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if mmcv.is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_classes()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels + + +def get_palette(dataset): + """Get class palette (RGB) of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if mmcv.is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_palette()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels diff --git a/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py b/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc100c8f96e817a6ed2666f7c9f762af2463b48 --- /dev/null +++ b/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py @@ -0,0 +1,109 @@ +import os.path as osp + +from annotator.uniformer.mmcv.runner import DistEvalHook as _DistEvalHook +from annotator.uniformer.mmcv.runner import EvalHook as _EvalHook + + +class EvalHook(_EvalHook): + """Single GPU EvalHook, with efficient test support. + + Args: + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + Returns: + list: The prediction results. + """ + + greater_keys = ['mIoU', 'mAcc', 'aAcc'] + + def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): + super().__init__(*args, by_epoch=by_epoch, **kwargs) + self.efficient_test = efficient_test + + def after_train_iter(self, runner): + """After train epoch hook. + + Override default ``single_gpu_test``. + """ + if self.by_epoch or not self.every_n_iters(runner, self.interval): + return + from annotator.uniformer.mmseg.apis import single_gpu_test + runner.log_buffer.clear() + results = single_gpu_test( + runner.model, + self.dataloader, + show=False, + efficient_test=self.efficient_test) + self.evaluate(runner, results) + + def after_train_epoch(self, runner): + """After train epoch hook. + + Override default ``single_gpu_test``. + """ + if not self.by_epoch or not self.every_n_epochs(runner, self.interval): + return + from annotator.uniformer.mmseg.apis import single_gpu_test + runner.log_buffer.clear() + results = single_gpu_test(runner.model, self.dataloader, show=False) + self.evaluate(runner, results) + + +class DistEvalHook(_DistEvalHook): + """Distributed EvalHook, with efficient test support. + + Args: + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + Returns: + list: The prediction results. + """ + + greater_keys = ['mIoU', 'mAcc', 'aAcc'] + + def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): + super().__init__(*args, by_epoch=by_epoch, **kwargs) + self.efficient_test = efficient_test + + def after_train_iter(self, runner): + """After train epoch hook. + + Override default ``multi_gpu_test``. + """ + if self.by_epoch or not self.every_n_iters(runner, self.interval): + return + from annotator.uniformer.mmseg.apis import multi_gpu_test + runner.log_buffer.clear() + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=osp.join(runner.work_dir, '.eval_hook'), + gpu_collect=self.gpu_collect, + efficient_test=self.efficient_test) + if runner.rank == 0: + print('\n') + self.evaluate(runner, results) + + def after_train_epoch(self, runner): + """After train epoch hook. + + Override default ``multi_gpu_test``. + """ + if not self.by_epoch or not self.every_n_epochs(runner, self.interval): + return + from annotator.uniformer.mmseg.apis import multi_gpu_test + runner.log_buffer.clear() + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=osp.join(runner.work_dir, '.eval_hook'), + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + self.evaluate(runner, results) diff --git a/annotator/uniformer/mmseg/core/evaluation/metrics.py b/annotator/uniformer/mmseg/core/evaluation/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..16c7dd47cadd53cf1caaa194e28a343f2aacc599 --- /dev/null +++ b/annotator/uniformer/mmseg/core/evaluation/metrics.py @@ -0,0 +1,326 @@ +from collections import OrderedDict + +import annotator.uniformer.mmcv as mmcv +import numpy as np +import torch + + +def f_score(precision, recall, beta=1): + """calcuate the f-score value. + + Args: + precision (float | torch.Tensor): The precision value. + recall (float | torch.Tensor): The recall value. + beta (int): Determines the weight of recall in the combined score. + Default: False. + + Returns: + [torch.tensor]: The f-score value. + """ + score = (1 + beta**2) * (precision * recall) / ( + (beta**2 * precision) + recall) + return score + + +def intersect_and_union(pred_label, + label, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate intersection and Union. + + Args: + pred_label (ndarray | str): Prediction segmentation map + or predict result filename. + label (ndarray | str): Ground truth segmentation map + or label filename. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. The parameter will + work only when label is str. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. The parameter will + work only when label is str. Default: False. + + Returns: + torch.Tensor: The intersection of prediction and ground truth + histogram on all classes. + torch.Tensor: The union of prediction and ground truth histogram on + all classes. + torch.Tensor: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. + """ + + if isinstance(pred_label, str): + pred_label = torch.from_numpy(np.load(pred_label)) + else: + pred_label = torch.from_numpy((pred_label)) + + if isinstance(label, str): + label = torch.from_numpy( + mmcv.imread(label, flag='unchanged', backend='pillow')) + else: + label = torch.from_numpy(label) + + if label_map is not None: + for old_id, new_id in label_map.items(): + label[label == old_id] = new_id + if reduce_zero_label: + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + + mask = (label != ignore_index) + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_label = torch.histc( + label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_union = area_pred_label + area_label - area_intersect + return area_intersect, area_union, area_pred_label, area_label + + +def total_intersect_and_union(results, + gt_seg_maps, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate Total Intersection and Union. + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + + Returns: + ndarray: The intersection of prediction and ground truth histogram + on all classes. + ndarray: The union of prediction and ground truth histogram on all + classes. + ndarray: The prediction histogram on all classes. + ndarray: The ground truth histogram on all classes. + """ + num_imgs = len(results) + assert len(gt_seg_maps) == num_imgs + total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_union = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_label = torch.zeros((num_classes, ), dtype=torch.float64) + for i in range(num_imgs): + area_intersect, area_union, area_pred_label, area_label = \ + intersect_and_union( + results[i], gt_seg_maps[i], num_classes, ignore_index, + label_map, reduce_zero_label) + total_area_intersect += area_intersect + total_area_union += area_union + total_area_pred_label += area_pred_label + total_area_label += area_label + return total_area_intersect, total_area_union, total_area_pred_label, \ + total_area_label + + +def mean_iou(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): + """Calculate Mean Intersection and Union (mIoU) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + + Returns: + dict[str, float | ndarray]: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category IoU, shape (num_classes, ). + """ + iou_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mIoU'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label) + return iou_result + + +def mean_dice(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): + """Calculate Mean Dice (mDice) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + + Returns: + dict[str, float | ndarray]: Default metrics. + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category dice, shape (num_classes, ). + """ + + dice_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mDice'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label) + return dice_result + + +def mean_fscore(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, + beta=1): + """Calculate Mean Intersection and Union (mIoU) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + beta (int): Determines the weight of recall in the combined score. + Default: False. + + + Returns: + dict[str, float | ndarray]: Default metrics. + float: Overall accuracy on all images. + ndarray: Per category recall, shape (num_classes, ). + ndarray: Per category precision, shape (num_classes, ). + ndarray: Per category f-score, shape (num_classes, ). + """ + fscore_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mFscore'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label, + beta=beta) + return fscore_result + + +def eval_metrics(results, + gt_seg_maps, + num_classes, + ignore_index, + metrics=['mIoU'], + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, + beta=1): + """Calculate evaluation metrics + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). + """ + if isinstance(metrics, str): + metrics = [metrics] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metrics).issubset(set(allowed_metrics)): + raise KeyError('metrics {} is not supported'.format(metrics)) + + total_area_intersect, total_area_union, total_area_pred_label, \ + total_area_label = total_intersect_and_union( + results, gt_seg_maps, num_classes, ignore_index, label_map, + reduce_zero_label) + all_acc = total_area_intersect.sum() / total_area_label.sum() + ret_metrics = OrderedDict({'aAcc': all_acc}) + for metric in metrics: + if metric == 'mIoU': + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics['IoU'] = iou + ret_metrics['Acc'] = acc + elif metric == 'mDice': + dice = 2 * total_area_intersect / ( + total_area_pred_label + total_area_label) + acc = total_area_intersect / total_area_label + ret_metrics['Dice'] = dice + ret_metrics['Acc'] = acc + elif metric == 'mFscore': + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor( + [f_score(x[0], x[1], beta) for x in zip(precision, recall)]) + ret_metrics['Fscore'] = f_value + ret_metrics['Precision'] = precision + ret_metrics['Recall'] = recall + + ret_metrics = { + metric: value.numpy() + for metric, value in ret_metrics.items() + } + if nan_to_num is not None: + ret_metrics = OrderedDict({ + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in ret_metrics.items() + }) + return ret_metrics diff --git a/annotator/uniformer/mmseg/core/seg/__init__.py b/annotator/uniformer/mmseg/core/seg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93bc129b685e4a3efca2cc891729981b2865900d --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/__init__.py @@ -0,0 +1,4 @@ +from .builder import build_pixel_sampler +from .sampler import BasePixelSampler, OHEMPixelSampler + +__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] diff --git a/annotator/uniformer/mmseg/core/seg/builder.py b/annotator/uniformer/mmseg/core/seg/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..db61f03d4abb2072f2532ce4429c0842495e015b --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/builder.py @@ -0,0 +1,8 @@ +from annotator.uniformer.mmcv.utils import Registry, build_from_cfg + +PIXEL_SAMPLERS = Registry('pixel sampler') + + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) diff --git a/annotator/uniformer/mmseg/core/seg/sampler/__init__.py b/annotator/uniformer/mmseg/core/seg/sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..332b242c03d1c5e80d4577df442a9a037b1816e1 --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/sampler/__init__.py @@ -0,0 +1,4 @@ +from .base_pixel_sampler import BasePixelSampler +from .ohem_pixel_sampler import OHEMPixelSampler + +__all__ = ['BasePixelSampler', 'OHEMPixelSampler'] diff --git a/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py b/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..b75b1566c9f18169cee51d4b55d75e0357b69c57 --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py @@ -0,0 +1,12 @@ +from abc import ABCMeta, abstractmethod + + +class BasePixelSampler(metaclass=ABCMeta): + """Base class of pixel sampler.""" + + def __init__(self, **kwargs): + pass + + @abstractmethod + def sample(self, seg_logit, seg_label): + """Placeholder for sample function.""" diff --git a/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..88bb10d44026ba9f21756eaea9e550841cd59b9f --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py @@ -0,0 +1,76 @@ +import torch +import torch.nn.functional as F + +from ..builder import PIXEL_SAMPLERS +from .base_pixel_sampler import BasePixelSampler + + +@PIXEL_SAMPLERS.register_module() +class OHEMPixelSampler(BasePixelSampler): + """Online Hard Example Mining Sampler for segmentation. + + Args: + context (nn.Module): The context of sampler, subclass of + :obj:`BaseDecodeHead`. + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: None. + min_kept (int, optional): The minimum number of predictions to keep. + Default: 100000. + """ + + def __init__(self, context, thresh=None, min_kept=100000): + super(OHEMPixelSampler, self).__init__() + self.context = context + assert min_kept > 1 + self.thresh = thresh + self.min_kept = min_kept + + def sample(self, seg_logit, seg_label): + """Sample pixels that have high loss or with low prediction confidence. + + Args: + seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) + seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) + + Returns: + torch.Tensor: segmentation weight, shape (N, H, W) + """ + with torch.no_grad(): + assert seg_logit.shape[2:] == seg_label.shape[2:] + assert seg_label.shape[1] == 1 + seg_label = seg_label.squeeze(1).long() + batch_kept = self.min_kept * seg_label.size(0) + valid_mask = seg_label != self.context.ignore_index + seg_weight = seg_logit.new_zeros(size=seg_label.size()) + valid_seg_weight = seg_weight[valid_mask] + if self.thresh is not None: + seg_prob = F.softmax(seg_logit, dim=1) + + tmp_seg_label = seg_label.clone().unsqueeze(1) + tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 + seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) + sort_prob, sort_indices = seg_prob[valid_mask].sort() + + if sort_prob.numel() > 0: + min_threshold = sort_prob[min(batch_kept, + sort_prob.numel() - 1)] + else: + min_threshold = 0.0 + threshold = max(min_threshold, self.thresh) + valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. + else: + losses = self.context.loss_decode( + seg_logit, + seg_label, + weight=None, + ignore_index=self.context.ignore_index, + reduction_override='none') + # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa + _, sort_indices = losses[valid_mask].sort(descending=True) + valid_seg_weight[sort_indices[:batch_kept]] = 1. + + seg_weight[valid_mask] = valid_seg_weight + + return seg_weight diff --git a/annotator/uniformer/mmseg/core/utils/__init__.py b/annotator/uniformer/mmseg/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2678b321c295bcceaef945111ac3524be19d6e4 --- /dev/null +++ b/annotator/uniformer/mmseg/core/utils/__init__.py @@ -0,0 +1,3 @@ +from .misc import add_prefix + +__all__ = ['add_prefix'] diff --git a/annotator/uniformer/mmseg/core/utils/misc.py b/annotator/uniformer/mmseg/core/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..eb862a82bd47c8624db3dd5c6fb6ad8a03b62466 --- /dev/null +++ b/annotator/uniformer/mmseg/core/utils/misc.py @@ -0,0 +1,17 @@ +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f'{prefix}.{name}'] = value + + return outputs diff --git a/annotator/uniformer/mmseg/datasets/__init__.py b/annotator/uniformer/mmseg/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebeaef4a28ef655e43578552a8aef6b77f13a636 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/__init__.py @@ -0,0 +1,19 @@ +from .ade import ADE20KDataset +from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset +from .chase_db1 import ChaseDB1Dataset +from .cityscapes import CityscapesDataset +from .custom import CustomDataset +from .dataset_wrappers import ConcatDataset, RepeatDataset +from .drive import DRIVEDataset +from .hrf import HRFDataset +from .pascal_context import PascalContextDataset, PascalContextDataset59 +from .stare import STAREDataset +from .voc import PascalVOCDataset + +__all__ = [ + 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', + 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', + 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', + 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', + 'STAREDataset' +] diff --git a/annotator/uniformer/mmseg/datasets/ade.py b/annotator/uniformer/mmseg/datasets/ade.py new file mode 100644 index 0000000000000000000000000000000000000000..5913e43775ed4920b6934c855eb5a37c54218ebf --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/ade.py @@ -0,0 +1,84 @@ +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class ADE20KDataset(CustomDataset): + """ADE20K dataset. + + In segmentation map annotation for ADE20K, 0 stands for background, which + is not included in 150 categories. ``reduce_zero_label`` is fixed to True. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. + """ + CLASSES = ( + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag') + + PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + def __init__(self, **kwargs): + super(ADE20KDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) diff --git a/annotator/uniformer/mmseg/datasets/builder.py b/annotator/uniformer/mmseg/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0798b14cd8b39fc58d8f2a4930f1e079b5bf8b55 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/builder.py @@ -0,0 +1,169 @@ +import copy +import platform +import random +from functools import partial + +import numpy as np +from annotator.uniformer.mmcv.parallel import collate +from annotator.uniformer.mmcv.runner import get_dist_info +from annotator.uniformer.mmcv.utils import Registry, build_from_cfg +from annotator.uniformer.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader +from torch.utils.data import DistributedSampler + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + hard_limit = rlimit[1] + soft_limit = min(4096, hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + +DATASETS = Registry('dataset') +PIPELINES = Registry('pipeline') + + +def _concat_dataset(cfg, default_args=None): + """Build :obj:`ConcatDataset by.""" + from .dataset_wrappers import ConcatDataset + img_dir = cfg['img_dir'] + ann_dir = cfg.get('ann_dir', None) + split = cfg.get('split', None) + num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 + if ann_dir is not None: + num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 + else: + num_ann_dir = 0 + if split is not None: + num_split = len(split) if isinstance(split, (list, tuple)) else 1 + else: + num_split = 0 + if num_img_dir > 1: + assert num_img_dir == num_ann_dir or num_ann_dir == 0 + assert num_img_dir == num_split or num_split == 0 + else: + assert num_split == num_ann_dir or num_ann_dir <= 1 + num_dset = max(num_split, num_img_dir) + + datasets = [] + for i in range(num_dset): + data_cfg = copy.deepcopy(cfg) + if isinstance(img_dir, (list, tuple)): + data_cfg['img_dir'] = img_dir[i] + if isinstance(ann_dir, (list, tuple)): + data_cfg['ann_dir'] = ann_dir[i] + if isinstance(split, (list, tuple)): + data_cfg['split'] = split[i] + datasets.append(build_dataset(data_cfg, default_args)) + + return ConcatDataset(datasets) + + +def build_dataset(cfg, default_args=None): + """Build datasets.""" + from .dataset_wrappers import ConcatDataset, RepeatDataset + if isinstance(cfg, (list, tuple)): + dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) + elif cfg['type'] == 'RepeatDataset': + dataset = RepeatDataset( + build_dataset(cfg['dataset'], default_args), cfg['times']) + elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( + cfg.get('split', None), (list, tuple)): + dataset = _concat_dataset(cfg, default_args) + else: + dataset = build_from_cfg(cfg, DATASETS, default_args) + + return dataset + + +def build_dataloader(dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + drop_last=False, + pin_memory=True, + dataloader_type='PoolDataLoader', + **kwargs): + """Build PyTorch DataLoader. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + seed (int | None): Seed to be used. Default: None. + drop_last (bool): Whether to drop the last incomplete batch in epoch. + Default: False + pin_memory (bool): Whether to use pin_memory in DataLoader. + Default: True + dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader' + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + if dist: + sampler = DistributedSampler( + dataset, world_size, rank, shuffle=shuffle) + shuffle = False + batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + sampler = None + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + assert dataloader_type in ( + 'DataLoader', + 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}' + + if dataloader_type == 'PoolDataLoader': + dataloader = PoolDataLoader + elif dataloader_type == 'DataLoader': + dataloader = DataLoader + + data_loader = dataloader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=pin_memory, + shuffle=shuffle, + worker_init_fn=init_fn, + drop_last=drop_last, + **kwargs) + + return data_loader + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Worker init func for dataloader. + + The seed of each worker equals to num_worker * rank + worker_id + user_seed + + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. + """ + + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/annotator/uniformer/mmseg/datasets/chase_db1.py b/annotator/uniformer/mmseg/datasets/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc29bea14704a4407f83474610cbc3bef32c708 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/chase_db1.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class ChaseDB1Dataset(CustomDataset): + """Chase_db1 dataset. + + In segmentation map annotation for Chase_db1, 0 stands for background, + which is included in 2 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_1stHO.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(ChaseDB1Dataset, self).__init__( + img_suffix='.png', + seg_map_suffix='_1stHO.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/annotator/uniformer/mmseg/datasets/cityscapes.py b/annotator/uniformer/mmseg/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..81e47a914a1aa2e5458e18669d65ffb742f46fc6 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/cityscapes.py @@ -0,0 +1,217 @@ +import os.path as osp +import tempfile + +import annotator.uniformer.mmcv as mmcv +import numpy as np +from annotator.uniformer.mmcv.utils import print_log +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class CityscapesDataset(CustomDataset): + """Cityscapes dataset. + + The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is + fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. + """ + + CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle') + + PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], + [0, 80, 100], [0, 0, 230], [119, 11, 32]] + + def __init__(self, **kwargs): + super(CityscapesDataset, self).__init__( + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtFine_labelTrainIds.png', + **kwargs) + + @staticmethod + def _convert_to_label_id(result): + """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) + import cityscapesscripts.helpers.labels as CSLabels + result_copy = result.copy() + for trainId, label in CSLabels.trainId2label.items(): + result_copy[result == trainId] = label.id + + return result_copy + + def results2img(self, results, imgfile_prefix, to_label_id): + """Write the segmentation results to images. + + Args: + results (list[list | tuple | ndarray]): Testing results of the + dataset. + imgfile_prefix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", + the png files will be named "somepath/xxx.png". + to_label_id (bool): whether convert output to label_id for + submission + + Returns: + list[str: str]: result txt files which contains corresponding + semantic segmentation images. + """ + mmcv.mkdir_or_exist(imgfile_prefix) + result_files = [] + prog_bar = mmcv.ProgressBar(len(self)) + for idx in range(len(self)): + result = results[idx] + if to_label_id: + result = self._convert_to_label_id(result) + filename = self.img_infos[idx]['filename'] + basename = osp.splitext(osp.basename(filename))[0] + + png_filename = osp.join(imgfile_prefix, f'{basename}.png') + + output = Image.fromarray(result.astype(np.uint8)).convert('P') + import cityscapesscripts.helpers.labels as CSLabels + palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) + for label_id, label in CSLabels.id2label.items(): + palette[label_id] = label.color + + output.putpalette(palette) + output.save(png_filename) + result_files.append(png_filename) + prog_bar.update() + + return result_files + + def format_results(self, results, imgfile_prefix=None, to_label_id=True): + """Format the results into dir (standard format for Cityscapes + evaluation). + + Args: + results (list): Testing results of the dataset. + imgfile_prefix (str | None): The prefix of images files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". If not specified, a temp file will be created. + Default: None. + to_label_id (bool): whether convert output to label_id for + submission. Default: False + + Returns: + tuple: (result_files, tmp_dir), result_files is a list containing + the image paths, tmp_dir is the temporal directory created + for saving json/png files when img_prefix is not specified. + """ + + assert isinstance(results, list), 'results must be a list' + assert len(results) == len(self), ( + 'The length of results is not equal to the dataset len: ' + f'{len(results)} != {len(self)}') + + if imgfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + imgfile_prefix = tmp_dir.name + else: + tmp_dir = None + result_files = self.results2img(results, imgfile_prefix, to_label_id) + + return result_files, tmp_dir + + def evaluate(self, + results, + metric='mIoU', + logger=None, + imgfile_prefix=None, + efficient_test=False): + """Evaluation in Cityscapes/default protocol. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file, + for cityscapes evaluation only. It includes the file path and + the prefix of filename, e.g., "a/b/prefix". + If results are evaluated with cityscapes protocol, it would be + the prefix of output png files. The output files would be + png images under folder "a/b/prefix/xxx.png", where "xxx" is + the image name of cityscapes. If not specified, a temp file + will be created for evaluation. + Default: None. + + Returns: + dict[str, float]: Cityscapes/default metrics. + """ + + eval_results = dict() + metrics = metric.copy() if isinstance(metric, list) else [metric] + if 'cityscapes' in metrics: + eval_results.update( + self._evaluate_cityscapes(results, logger, imgfile_prefix)) + metrics.remove('cityscapes') + if len(metrics) > 0: + eval_results.update( + super(CityscapesDataset, + self).evaluate(results, metrics, logger, efficient_test)) + + return eval_results + + def _evaluate_cityscapes(self, results, logger, imgfile_prefix): + """Evaluation in Cityscapes protocol. + + Args: + results (list): Testing results of the dataset. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file + + Returns: + dict[str: float]: Cityscapes evaluation results. + """ + try: + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa + except ImportError: + raise ImportError('Please run "pip install cityscapesscripts" to ' + 'install cityscapesscripts first.') + msg = 'Evaluating in Cityscapes style' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + result_files, tmp_dir = self.format_results(results, imgfile_prefix) + + if tmp_dir is None: + result_dir = imgfile_prefix + else: + result_dir = tmp_dir.name + + eval_results = dict() + print_log(f'Evaluating results under {result_dir} ...', logger=logger) + + CSEval.args.evalInstLevelScore = True + CSEval.args.predictionPath = osp.abspath(result_dir) + CSEval.args.evalPixelAccuracy = True + CSEval.args.JSONOutput = False + + seg_map_list = [] + pred_list = [] + + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + for seg_map in mmcv.scandir( + self.ann_dir, 'gtFine_labelIds.png', recursive=True): + seg_map_list.append(osp.join(self.ann_dir, seg_map)) + pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) + + eval_results.update( + CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) + + if tmp_dir is not None: + tmp_dir.cleanup() + + return eval_results diff --git a/annotator/uniformer/mmseg/datasets/custom.py b/annotator/uniformer/mmseg/datasets/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..d8eb2a709cc7a3a68fc6a1e3a1ad98faef4c5b7b --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/custom.py @@ -0,0 +1,400 @@ +import os +import os.path as osp +from collections import OrderedDict +from functools import reduce + +import annotator.uniformer.mmcv as mmcv +import numpy as np +from annotator.uniformer.mmcv.utils import print_log +from prettytable import PrettyTable +from torch.utils.data import Dataset + +from annotator.uniformer.mmseg.core import eval_metrics +from annotator.uniformer.mmseg.utils import get_root_logger +from .builder import DATASETS +from .pipelines import Compose + + +@DATASETS.register_module() +class CustomDataset(Dataset): + """Custom dataset for semantic segmentation. An example of file structure + is as followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The img/gt_semantic_seg pair of CustomDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/tutorials/new_dataset.md`` for more details. + + + Args: + pipeline (list[dict]): Processing pipeline + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. Default: '.jpg' + ann_dir (str, optional): Path to annotation directory. Default: None + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + split (str, optional): Split txt file. If split is specified, only + file with suffix in the splits will be loaded. Otherwise, all + images in img_dir/ann_dir will be loaded. Default: None + data_root (str, optional): Data root for img_dir/ann_dir. Default: + None. + test_mode (bool): If test_mode=True, gt wouldn't be loaded. + ignore_index (int): The label index to be ignored. Default: 255 + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default: False + classes (str | Sequence[str], optional): Specify classes to load. + If is None, ``cls.CLASSES`` will be used. Default: None. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, and + self.PALETTE is None, random palette will be generated. + Default: None + """ + + CLASSES = None + + PALETTE = None + + def __init__(self, + pipeline, + img_dir, + img_suffix='.jpg', + ann_dir=None, + seg_map_suffix='.png', + split=None, + data_root=None, + test_mode=False, + ignore_index=255, + reduce_zero_label=False, + classes=None, + palette=None): + self.pipeline = Compose(pipeline) + self.img_dir = img_dir + self.img_suffix = img_suffix + self.ann_dir = ann_dir + self.seg_map_suffix = seg_map_suffix + self.split = split + self.data_root = data_root + self.test_mode = test_mode + self.ignore_index = ignore_index + self.reduce_zero_label = reduce_zero_label + self.label_map = None + self.CLASSES, self.PALETTE = self.get_classes_and_palette( + classes, palette) + + # join paths if data_root is specified + if self.data_root is not None: + if not osp.isabs(self.img_dir): + self.img_dir = osp.join(self.data_root, self.img_dir) + if not (self.ann_dir is None or osp.isabs(self.ann_dir)): + self.ann_dir = osp.join(self.data_root, self.ann_dir) + if not (self.split is None or osp.isabs(self.split)): + self.split = osp.join(self.data_root, self.split) + + # load annotations + self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, + self.ann_dir, + self.seg_map_suffix, self.split) + + def __len__(self): + """Total number of samples of data.""" + return len(self.img_infos) + + def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, + split): + """Load annotation from directory. + + Args: + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. + ann_dir (str|None): Path to annotation directory. + seg_map_suffix (str|None): Suffix of segmentation maps. + split (str|None): Split txt file. If split is specified, only file + with suffix in the splits will be loaded. Otherwise, all images + in img_dir/ann_dir will be loaded. Default: None + + Returns: + list[dict]: All image info of dataset. + """ + + img_infos = [] + if split is not None: + with open(split) as f: + for line in f: + img_name = line.strip() + img_info = dict(filename=img_name + img_suffix) + if ann_dir is not None: + seg_map = img_name + seg_map_suffix + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + else: + for img in mmcv.scandir(img_dir, img_suffix, recursive=True): + img_info = dict(filename=img) + if ann_dir is not None: + seg_map = img.replace(img_suffix, seg_map_suffix) + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + + print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) + return img_infos + + def get_ann_info(self, idx): + """Get annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + return self.img_infos[idx]['ann'] + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['seg_fields'] = [] + results['img_prefix'] = self.img_dir + results['seg_prefix'] = self.ann_dir + if self.custom_classes: + results['label_map'] = self.label_map + + def __getitem__(self, idx): + """Get training/test data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training/test data (with annotation if `test_mode` is set + False). + """ + + if self.test_mode: + return self.prepare_test_img(idx) + else: + return self.prepare_train_img(idx) + + def prepare_train_img(self, idx): + """Get training data and annotations after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. + """ + + img_info = self.img_infos[idx] + ann_info = self.get_ann_info(idx) + results = dict(img_info=img_info, ann_info=ann_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def prepare_test_img(self, idx): + """Get testing data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by + pipeline. + """ + + img_info = self.img_infos[idx] + results = dict(img_info=img_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def format_results(self, results, **kwargs): + """Place holder to format result to dataset specific output.""" + + def get_gt_seg_maps(self, efficient_test=False): + """Get ground truth segmentation maps for evaluation.""" + gt_seg_maps = [] + for img_info in self.img_infos: + seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) + if efficient_test: + gt_seg_map = seg_map + else: + gt_seg_map = mmcv.imread( + seg_map, flag='unchanged', backend='pillow') + gt_seg_maps.append(gt_seg_map) + return gt_seg_maps + + def get_classes_and_palette(self, classes=None, palette=None): + """Get class names of current dataset. + + Args: + classes (Sequence[str] | str | None): If classes is None, use + default CLASSES defined by builtin dataset. If classes is a + string, take it as a file name. The file contains the name of + classes where each line contains one class name. If classes is + a tuple or list, override the CLASSES defined by the dataset. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, random + palette will be generated. Default: None + """ + if classes is None: + self.custom_classes = False + return self.CLASSES, self.PALETTE + + self.custom_classes = True + if isinstance(classes, str): + # take it as a file path + class_names = mmcv.list_from_file(classes) + elif isinstance(classes, (tuple, list)): + class_names = classes + else: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + if self.CLASSES: + if not set(classes).issubset(self.CLASSES): + raise ValueError('classes is not a subset of CLASSES.') + + # dictionary, its keys are the old label ids and its values + # are the new label ids. + # used for changing pixel labels in load_annotations. + self.label_map = {} + for i, c in enumerate(self.CLASSES): + if c not in class_names: + self.label_map[i] = -1 + else: + self.label_map[i] = classes.index(c) + + palette = self.get_palette_for_custom_classes(class_names, palette) + + return class_names, palette + + def get_palette_for_custom_classes(self, class_names, palette=None): + + if self.label_map is not None: + # return subset of palette + palette = [] + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + if new_id != -1: + palette.append(self.PALETTE[old_id]) + palette = type(self.PALETTE)(palette) + + elif palette is None: + if self.PALETTE is None: + palette = np.random.randint(0, 255, size=(len(class_names), 3)) + else: + palette = self.PALETTE + + return palette + + def evaluate(self, + results, + metric='mIoU', + logger=None, + efficient_test=False, + **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. 'mIoU', + 'mDice' and 'mFscore' are supported. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str, float]: Default metrics. + """ + + if isinstance(metric, str): + metric = [metric] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metric).issubset(set(allowed_metrics)): + raise KeyError('metric {} is not supported'.format(metric)) + eval_results = {} + gt_seg_maps = self.get_gt_seg_maps(efficient_test) + if self.CLASSES is None: + num_classes = len( + reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) + else: + num_classes = len(self.CLASSES) + ret_metrics = eval_metrics( + results, + gt_seg_maps, + num_classes, + self.ignore_index, + metric, + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label) + + if self.CLASSES is None: + class_names = tuple(range(num_classes)) + else: + class_names = self.CLASSES + + # summary table + ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + + # each class table + ret_metrics.pop('aAcc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + + # for logger + class_table_data = PrettyTable() + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + + summary_table_data = PrettyTable() + for key, val in ret_metrics_summary.items(): + if key == 'aAcc': + summary_table_data.add_column(key, [val]) + else: + summary_table_data.add_column('m' + key, [val]) + + print_log('per class results:', logger) + print_log('\n' + class_table_data.get_string(), logger=logger) + print_log('Summary:', logger) + print_log('\n' + summary_table_data.get_string(), logger=logger) + + # each metric dict + for key, value in ret_metrics_summary.items(): + if key == 'aAcc': + eval_results[key] = value / 100.0 + else: + eval_results['m' + key] = value / 100.0 + + ret_metrics_class.pop('Class', None) + for key, value in ret_metrics_class.items(): + eval_results.update({ + key + '.' + str(name): value[idx] / 100.0 + for idx, name in enumerate(class_names) + }) + + if mmcv.is_list_of(results, str): + for file_name in results: + os.remove(file_name) + return eval_results diff --git a/annotator/uniformer/mmseg/datasets/dataset_wrappers.py b/annotator/uniformer/mmseg/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a5e957ec3b44465432617cf6e8f0b86a8a5efa --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/dataset_wrappers.py @@ -0,0 +1,50 @@ +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + +from .builder import DATASETS + + +@DATASETS.register_module() +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + concat the group flag for image aspect ratio. + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + """ + + def __init__(self, datasets): + super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES + self.PALETTE = datasets[0].PALETTE + + +@DATASETS.register_module() +class RepeatDataset(object): + """A wrapper of repeated dataset. + + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. + """ + + def __init__(self, dataset, times): + self.dataset = dataset + self.times = times + self.CLASSES = dataset.CLASSES + self.PALETTE = dataset.PALETTE + self._ori_len = len(self.dataset) + + def __getitem__(self, idx): + """Get item from original dataset.""" + return self.dataset[idx % self._ori_len] + + def __len__(self): + """The length is multiplied by ``times``""" + return self.times * self._ori_len diff --git a/annotator/uniformer/mmseg/datasets/drive.py b/annotator/uniformer/mmseg/datasets/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbfda8ae74bdf26c5aef197ff2866a7c7ad0cfd --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/drive.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class DRIVEDataset(CustomDataset): + """DRIVE dataset. + + In segmentation map annotation for DRIVE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_manual1.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(DRIVEDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='_manual1.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/annotator/uniformer/mmseg/datasets/hrf.py b/annotator/uniformer/mmseg/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..923203b51377f9344277fc561803d7a78bd2c684 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/hrf.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class HRFDataset(CustomDataset): + """HRF dataset. + + In segmentation map annotation for HRF, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(HRFDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/annotator/uniformer/mmseg/datasets/pascal_context.py b/annotator/uniformer/mmseg/datasets/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..541a63c66a13fb16fd52921e755715ad8d078fdd --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pascal_context.py @@ -0,0 +1,103 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PascalContextDataset(CustomDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + split (str): Split txt file for PascalContext. + """ + + CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', + 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus', + 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', + 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', + 'floor', 'flower', 'food', 'grass', 'ground', 'horse', + 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', + 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', + 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', + 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', + 'window', 'wood') + + PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + def __init__(self, split, **kwargs): + super(PascalContextDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + split=split, + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) and self.split is not None + + +@DATASETS.register_module() +class PascalContextDataset59(CustomDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + split (str): Split txt file for PascalContext. + """ + + CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', + 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', + 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', + 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', + 'food', 'grass', 'ground', 'horse', 'keyboard', 'light', + 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', + 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', + 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', + 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood') + + PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], + [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], + [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], + [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], + [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], + [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], + [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], + [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], + [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], + [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], + [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], + [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], + [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], + [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], + [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + def __init__(self, split, **kwargs): + super(PascalContextDataset59, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + split=split, + reduce_zero_label=True, + **kwargs) + assert osp.exists(self.img_dir) and self.split is not None diff --git a/annotator/uniformer/mmseg/datasets/pipelines/__init__.py b/annotator/uniformer/mmseg/datasets/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9046b07bb4ddea7a707a392b42e72db7c9df67 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/__init__.py @@ -0,0 +1,16 @@ +from .compose import Compose +from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, + Transpose, to_tensor) +from .loading import LoadAnnotations, LoadImageFromFile +from .test_time_aug import MultiScaleFlipAug +from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, + PhotoMetricDistortion, RandomCrop, RandomFlip, + RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) + +__all__ = [ + 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', + 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', + 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', + 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' +] diff --git a/annotator/uniformer/mmseg/datasets/pipelines/compose.py b/annotator/uniformer/mmseg/datasets/pipelines/compose.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfcbb925c6d4ebf849328b9f94ef6fc24359bf5 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/compose.py @@ -0,0 +1,51 @@ +import collections + +from annotator.uniformer.mmcv.utils import build_from_cfg + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Compose(object): + """Compose multiple transforms sequentially. + + Args: + transforms (Sequence[dict | callable]): Sequence of transform object or + config dict to be composed. + """ + + def __init__(self, transforms): + assert isinstance(transforms, collections.abc.Sequence) + self.transforms = [] + for transform in transforms: + if isinstance(transform, dict): + transform = build_from_cfg(transform, PIPELINES) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict') + + def __call__(self, data): + """Call function to apply transforms sequentially. + + Args: + data (dict): A result dict contains the data to transform. + + Returns: + dict: Transformed data. + """ + + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += f' {t}' + format_string += '\n)' + return format_string diff --git a/annotator/uniformer/mmseg/datasets/pipelines/formating.py b/annotator/uniformer/mmseg/datasets/pipelines/formating.py new file mode 100644 index 0000000000000000000000000000000000000000..97db85f4f9db39fb86ba77ead7d1a8407d810adb --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/formating.py @@ -0,0 +1,288 @@ +from collections.abc import Sequence + +import annotator.uniformer.mmcv as mmcv +import numpy as np +import torch +from annotator.uniformer.mmcv.parallel import DataContainer as DC + +from ..builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PIPELINES.register_module() +class ToTensor(object): + """Convert some results to :obj:`torch.Tensor` by given keys. + + Args: + keys (Sequence[str]): Keys that need to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert data in results to :obj:`torch.Tensor`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted + to :obj:`torch.Tensor`. + """ + + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class ImageToTensor(object): + """Convert image to :obj:`torch.Tensor` by given keys. + + The dimension order of input image is (H, W, C). The pipeline will convert + it to (C, H, W). If only 2 dimension (H, W) is given, the output would be + (1, H, W). + + Args: + keys (Sequence[str]): Key of images to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + + for key in self.keys: + img = results[key] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + results[key] = to_tensor(img.transpose(2, 0, 1)) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class Transpose(object): + """Transpose some results by given keys. + + Args: + keys (Sequence[str]): Keys of results to be transposed. + order (Sequence[int]): Order of transpose. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@PIPELINES.register_module() +class ToDataContainer(object): + """Convert results to :obj:`mmcv.DataContainer` by given fields. + + Args: + fields (Sequence[dict]): Each field is a dict like + ``dict(key='xxx', **kwargs)``. The ``key`` in result will + be converted to :obj:`mmcv.DataContainer` with ``**kwargs``. + Default: ``(dict(key='img', stack=True), + dict(key='gt_semantic_seg'))``. + """ + + def __init__(self, + fields=(dict(key='img', + stack=True), dict(key='gt_semantic_seg'))): + self.fields = fields + + def __call__(self, results): + """Call function to convert data in results to + :obj:`mmcv.DataContainer`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted to + :obj:`mmcv.DataContainer`. + """ + + for field in self.fields: + field = field.copy() + key = field.pop('key') + results[key] = DC(results[key], **field) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(fields={self.fields})' + + +@PIPELINES.register_module() +class DefaultFormatBundle(object): + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img" + and "gt_semantic_seg". These fields are formatted as follows. + + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, + (3)to DataContainer (stack=True) + """ + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with + default bundle. + """ + + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + if 'gt_semantic_seg' in results: + # convert to long + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, + ...].astype(np.int64)), + stack=True) + return results + + def __repr__(self): + return self.__class__.__name__ + + +@PIPELINES.register_module() +class Collect(object): + """Collect data from the loader relevant to the specific task. + + This is usually the last stage of the data loader pipeline. Typically keys + is set to some subset of "img", "gt_semantic_seg". + + The "img_meta" item is always populated. The contents of the "img_meta" + dictionary depends on "meta_keys". By default this includes: + + - "img_shape": shape of the image input to the network as a tuple + (h, w, c). Note that images may be zero padded on the bottom/right + if the batch tensor is larger than this shape. + + - "scale_factor": a float indicating the preprocessing scale + + - "flip": a boolean indicating if image flip transform was used + + - "filename": path to the image file + + - "ori_shape": original shape of the image as a tuple (h, w, c) + + - "pad_shape": image shape after padding + + - "img_norm_cfg": a dict of normalization information: + - mean - per channel mean subtraction + - std - per channel std divisor + - to_rgb - bool indicating if bgr was converted to rgb + + Args: + keys (Sequence[str]): Keys of results to be collected in ``data``. + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[img_metas]``. + Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'img_norm_cfg')`` + """ + + def __init__(self, + keys, + meta_keys=('filename', 'ori_filename', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction', 'img_norm_cfg')): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + """Call function to collect keys in results. The keys in ``meta_keys`` + will be converted to :obj:mmcv.DataContainer. + + Args: + results (dict): Result dict contains the data to collect. + + Returns: + dict: The result dict contains the following keys + - keys in``self.keys`` + - ``img_metas`` + """ + + data = {} + img_meta = {} + for key in self.meta_keys: + img_meta[key] = results[key] + data['img_metas'] = DC(img_meta, cpu_only=True) + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, meta_keys={self.meta_keys})' diff --git a/annotator/uniformer/mmseg/datasets/pipelines/loading.py b/annotator/uniformer/mmseg/datasets/pipelines/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..d3692ae91f19b9c7ccf6023168788ff42c9e93e3 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/loading.py @@ -0,0 +1,153 @@ +import os.path as osp + +import annotator.uniformer.mmcv as mmcv +import numpy as np + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class LoadImageFromFile(object): + """Load an image from file. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename"). Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:`mmcv.imfrombytes`. + Defaults to 'color'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: + 'cv2' + """ + + def __init__(self, + to_float32=False, + color_type='color', + file_client_args=dict(backend='disk'), + imdecode_backend='cv2'): + self.to_float32 = to_float32 + self.color_type = color_type + self.file_client_args = file_client_args.copy() + self.file_client = None + self.imdecode_backend = imdecode_backend + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results.get('img_prefix') is not None: + filename = osp.join(results['img_prefix'], + results['img_info']['filename']) + else: + filename = results['img_info']['filename'] + img_bytes = self.file_client.get(filename) + img = mmcv.imfrombytes( + img_bytes, flag=self.color_type, backend=self.imdecode_backend) + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + results['ori_filename'] = results['img_info']['filename'] + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results['img_norm_cfg'] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(to_float32={self.to_float32},' + repr_str += f"color_type='{self.color_type}'," + repr_str += f"imdecode_backend='{self.imdecode_backend}')" + return repr_str + + +@PIPELINES.register_module() +class LoadAnnotations(object): + """Load annotations for semantic segmentation. + + Args: + reduce_zero_label (bool): Whether reduce all label value by 1. + Usually used for datasets where 0 is background label. + Default: False. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: + 'pillow' + """ + + def __init__(self, + reduce_zero_label=False, + file_client_args=dict(backend='disk'), + imdecode_backend='pillow'): + self.reduce_zero_label = reduce_zero_label + self.file_client_args = file_client_args.copy() + self.file_client = None + self.imdecode_backend = imdecode_backend + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results.get('seg_prefix', None) is not None: + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + else: + filename = results['ann_info']['seg_map'] + img_bytes = self.file_client.get(filename) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + # modify if custom classes + if results.get('label_map', None) is not None: + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg == old_id] = new_id + # reduce zero_label + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + results['gt_semantic_seg'] = gt_semantic_seg + results['seg_fields'].append('gt_semantic_seg') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label},' + repr_str += f"imdecode_backend='{self.imdecode_backend}')" + return repr_str diff --git a/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py b/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1611a04d9d927223c9afbe5bf68af04d62937a --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py @@ -0,0 +1,133 @@ +import warnings + +import annotator.uniformer.mmcv as mmcv + +from ..builder import PIPELINES +from .compose import Compose + + +@PIPELINES.register_module() +class MultiScaleFlipAug(object): + """Test-time augmentation with multiple scales and flipping. + + An example configuration is as followed: + + .. code-block:: + + img_scale=(2048, 1024), + img_ratios=[0.5, 1.0], + flip=True, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ] + + After MultiScaleFLipAug with above configuration, the results are wrapped + into lists of the same length as followed: + + .. code-block:: + + dict( + img=[...], + img_shape=[...], + scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] + flip=[False, True, False, True] + ... + ) + + Args: + transforms (list[dict]): Transforms to apply in each augmentation. + img_scale (None | tuple | list[tuple]): Images scales for resizing. + img_ratios (float | list[float]): Image ratios for resizing + flip (bool): Whether apply flip augmentation. Default: False. + flip_direction (str | list[str]): Flip augmentation directions, + options are "horizontal" and "vertical". If flip_direction is list, + multiple flip augmentations will be applied. + It has no effect when flip == False. Default: "horizontal". + """ + + def __init__(self, + transforms, + img_scale, + img_ratios=None, + flip=False, + flip_direction='horizontal'): + self.transforms = Compose(transforms) + if img_ratios is not None: + img_ratios = img_ratios if isinstance(img_ratios, + list) else [img_ratios] + assert mmcv.is_list_of(img_ratios, float) + if img_scale is None: + # mode 1: given img_scale=None and a range of image ratio + self.img_scale = None + assert mmcv.is_list_of(img_ratios, float) + elif isinstance(img_scale, tuple) and mmcv.is_list_of( + img_ratios, float): + assert len(img_scale) == 2 + # mode 2: given a scale and a range of image ratio + self.img_scale = [(int(img_scale[0] * ratio), + int(img_scale[1] * ratio)) + for ratio in img_ratios] + else: + # mode 3: given multiple scales + self.img_scale = img_scale if isinstance(img_scale, + list) else [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None + self.flip = flip + self.img_ratios = img_ratios + self.flip_direction = flip_direction if isinstance( + flip_direction, list) else [flip_direction] + assert mmcv.is_list_of(self.flip_direction, str) + if not self.flip and self.flip_direction != ['horizontal']: + warnings.warn( + 'flip_direction has no effect when flip is set to False') + if (self.flip + and not any([t['type'] == 'RandomFlip' for t in transforms])): + warnings.warn( + 'flip has no effect when RandomFlip is not in transforms') + + def __call__(self, results): + """Call function to apply test time augment transforms on results. + + Args: + results (dict): Result dict contains the data to transform. + + Returns: + dict[str: list]: The augmented data, where each value is wrapped + into a list. + """ + + aug_data = [] + if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): + h, w = results['img'].shape[:2] + img_scale = [(int(w * ratio), int(h * ratio)) + for ratio in self.img_ratios] + else: + img_scale = self.img_scale + flip_aug = [False, True] if self.flip else [False] + for scale in img_scale: + for flip in flip_aug: + for direction in self.flip_direction: + _results = results.copy() + _results['scale'] = scale + _results['flip'] = flip + _results['flip_direction'] = direction + data = self.transforms(_results) + aug_data.append(data) + # list of dict to dict of list + aug_data_dict = {key: [] for key in aug_data[0]} + for data in aug_data: + for key, val in data.items(): + aug_data_dict[key].append(val) + return aug_data_dict + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'img_scale={self.img_scale}, flip={self.flip})' + repr_str += f'flip_direction={self.flip_direction}' + return repr_str diff --git a/annotator/uniformer/mmseg/datasets/pipelines/transforms.py b/annotator/uniformer/mmseg/datasets/pipelines/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..94e869b252ef6d8b43604add2bbc02f034614bfb --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/transforms.py @@ -0,0 +1,889 @@ +import annotator.uniformer.mmcv as mmcv +import numpy as np +from annotator.uniformer.mmcv.utils import deprecated_api_warning, is_tuple_of +from numpy import random + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Resize(object): + """Resize images & seg. + + This transform resizes the input image to some scale. If the input dict + contains the key "scale", then the scale in the input dict is used, + otherwise the specified scale in the init method is used. + + ``img_scale`` can be None, a tuple (single-scale) or a list of tuple + (multi-scale). There are 4 multiscale modes: + + - ``ratio_range is not None``: + 1. When img_scale is None, img_scale is the shape of image in results + (img_scale = results['img'].shape[:2]) and the image is resized based + on the original size. (mode 1) + 2. When img_scale is a tuple (single-scale), randomly sample a ratio from + the ratio range and multiply it with the image scale. (mode 2) + + - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a + scale from the a range. (mode 3) + + - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a + scale from multiple scales. (mode 4) + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". + ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given img_scale=None and a range of image ratio + # mode 2: given a scale and a range of image ratio + assert self.img_scale is None or len(self.img_scale) == 1 + else: + # mode 3 and 4: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, + where ``img_scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and upper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where + ``img_scale`` is sampled scale and None is just a placeholder + to be consistent with :func:`random_select`. + """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where + ``scale`` is sampled ratio multiplied with ``img_scale`` and + None is just a placeholder to be consistent with + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into + ``results``, which would be used by subsequent pipelines. + """ + + if self.ratio_range is not None: + if self.img_scale is None: + h, w = results['img'].shape[:2] + scale, scale_idx = self.random_sample_ratio((w, h), + self.ratio_range) + else: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + results['img'], results['scale'], return_scale=True) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + results['img'], results['scale'], return_scale=True) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape # in case that there is no padding + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], results['scale'], interpolation='nearest') + else: + gt_seg = mmcv.imresize( + results[key], results['scale'], interpolation='nearest') + results[key] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in results: + self._random_scale(results) + self._resize_img(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(img_scale={self.img_scale}, ' + f'multiscale_mode={self.multiscale_mode}, ' + f'ratio_range={self.ratio_range}, ' + f'keep_ratio={self.keep_ratio})') + return repr_str + + +@PIPELINES.register_module() +class RandomFlip(object): + """Flip the image & seg. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + Args: + prob (float, optional): The flipping probability. Default: None. + direction(str, optional): The flipping direction. Options are + 'horizontal' and 'vertical'. Default: 'horizontal'. + """ + + @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip') + def __init__(self, prob=None, direction='horizontal'): + self.prob = prob + self.direction = direction + if prob is not None: + assert prob >= 0 and prob <= 1 + assert direction in ['horizontal', 'vertical'] + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added into + result dict. + """ + + if 'flip' not in results: + flip = True if np.random.rand() < self.prob else False + results['flip'] = flip + if 'flip_direction' not in results: + results['flip_direction'] = self.direction + if results['flip']: + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + # flip segs + for key in results.get('seg_fields', []): + # use copy() to make numpy stride positive + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']).copy() + return results + + def __repr__(self): + return self.__class__.__name__ + f'(prob={self.prob})' + + +@PIPELINES.register_module() +class Pad(object): + """Pad the image & mask. + + There are two padding modes: (1) pad to a fixed size and (2) pad to the + minimum size that is divisible by some number. + Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", + + Args: + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + """ + + def __init__(self, + size=None, + size_divisor=None, + pad_val=0, + seg_pad_val=255): + self.size = size + self.size_divisor = size_divisor + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + # only one of size and size_divisor should be valid + assert size is not None or size_divisor is not None + assert size is None or size_divisor is None + + def _pad_img(self, results): + """Pad images according to ``self.size``.""" + if self.size is not None: + padded_img = mmcv.impad( + results['img'], shape=self.size, pad_val=self.pad_val) + elif self.size_divisor is not None: + padded_img = mmcv.impad_to_multiple( + results['img'], self.size_divisor, pad_val=self.pad_val) + results['img'] = padded_img + results['pad_shape'] = padded_img.shape + results['pad_fixed_size'] = self.size + results['pad_size_divisor'] = self.size_divisor + + def _pad_seg(self, results): + """Pad masks according to ``results['pad_shape']``.""" + for key in results.get('seg_fields', []): + results[key] = mmcv.impad( + results[key], + shape=results['pad_shape'][:2], + pad_val=self.seg_pad_val) + + def __call__(self, results): + """Call function to pad images, masks, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + + self._pad_img(results) + self._pad_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \ + f'pad_val={self.pad_val})' + return repr_str + + +@PIPELINES.register_module() +class Normalize(object): + """Normalize the image. + + Added key is "img_norm_cfg". + + Args: + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB, + default is true. + """ + + def __init__(self, mean, std, to_rgb=True): + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def __call__(self, results): + """Call function to normalize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Normalized results, 'img_norm_cfg' key is added into + result dict. + """ + + results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, + self.to_rgb) + results['img_norm_cfg'] = dict( + mean=self.mean, std=self.std, to_rgb=self.to_rgb) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \ + f'{self.to_rgb})' + return repr_str + + +@PIPELINES.register_module() +class Rerange(object): + """Rerange the image pixel value. + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. + """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def __call__(self, results): + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Reranged results. + """ + + img = results['img'] + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + +@PIPELINES.register_module() +class CLAHE(object): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def __call__(self, results): + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + for i in range(results['img'].shape[2]): + results['img'][:, :, i] = mmcv.clahe( + np.array(results['img'][:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, '\ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + +@PIPELINES.register_module() +class RandomCrop(object): + """Random crop the image & seg. + + Args: + crop_size (tuple): Expected size after cropping, (h, w). + cat_max_ratio (float): The maximum ratio that single category could + occupy. + """ + + def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + def get_crop_bbox(self, img): + """Randomly get a crop bounding box.""" + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + def crop(self, img, crop_bbox): + """Crop from ``img``""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def __call__(self, results): + """Call function to randomly crop images, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + img = results['img'] + crop_bbox = self.get_crop_bbox(img) + if self.cat_max_ratio < 1.: + # Repeat 10 times + for _ in range(10): + seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum( + cnt) < self.cat_max_ratio: + break + crop_bbox = self.get_crop_bbox(img) + + # crop the image + img = self.crop(img, crop_bbox) + img_shape = img.shape + results['img'] = img + results['img_shape'] = img_shape + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@PIPELINES.register_module() +class RandomRotate(object): + """Rotate the image & seg. + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + def __call__(self, results): + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + + rotate = True if np.random.rand() < self.prob else False + degree = np.random.uniform(min(*self.degree), max(*self.degree)) + if rotate: + # rotate image + results['img'] = mmcv.imrotate( + results['img'], + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + +@PIPELINES.register_module() +class RGB2Gray(object): + """Convert RGB image to grayscale image. + + This transform calculate the weighted mean of input image channels with + ``weights`` and then expand the channels to ``out_channels``. When + ``out_channels`` is None, the number of output channels is the same as + input channels. + + Args: + out_channels (int): Expected number of output channels after + transforming. Default: None. + weights (tuple[float]): The weights to calculate the weighted mean. + Default: (0.299, 0.587, 0.114). + """ + + def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): + assert out_channels is None or out_channels > 0 + self.out_channels = out_channels + assert isinstance(weights, tuple) + for item in weights: + assert isinstance(item, (float, int)) + self.weights = weights + + def __call__(self, results): + """Call function to convert RGB image to grayscale image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with grayscale image. + """ + img = results['img'] + assert len(img.shape) == 3 + assert img.shape[2] == len(self.weights) + weights = np.array(self.weights).reshape((1, 1, -1)) + img = (img * weights).sum(2, keepdims=True) + if self.out_channels is None: + img = img.repeat(weights.shape[2], axis=2) + else: + img = img.repeat(self.out_channels, axis=2) + + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(out_channels={self.out_channels}, ' \ + f'weights={self.weights})' + return repr_str + + +@PIPELINES.register_module() +class AdjustGamma(object): + """Using gamma correction to process the image. + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def __call__(self, results): + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = mmcv.lut_transform( + np.array(results['img'], dtype=np.uint8), self.table) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + + +@PIPELINES.register_module() +class SegRescale(object): + """Rescale semantic segmentation maps. + + Args: + scale_factor (float): The scale factor of the final output. + """ + + def __init__(self, scale_factor=1): + self.scale_factor = scale_factor + + def __call__(self, results): + """Call function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. + """ + for key in results.get('seg_fields', []): + if self.scale_factor != 1: + results[key] = mmcv.imrescale( + results[key], self.scale_factor, interpolation='nearest') + return results + + def __repr__(self): + return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + + +@PIPELINES.register_module() +class PhotoMetricDistortion(object): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def convert(self, img, alpha=1, beta=0): + """Multiple with alpha and add beat with clip.""" + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img): + """Brightness distortion.""" + if random.randint(2): + return self.convert( + img, + beta=random.uniform(-self.brightness_delta, + self.brightness_delta)) + return img + + def contrast(self, img): + """Contrast distortion.""" + if random.randint(2): + return self.convert( + img, + alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + return img + + def saturation(self, img): + """Saturation distortion.""" + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, + self.saturation_upper)) + img = mmcv.hsv2bgr(img) + return img + + def hue(self, img): + """Hue distortion.""" + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, + 0] = (img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta)) % 180 + img = mmcv.hsv2bgr(img) + return img + + def __call__(self, results): + """Call function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + img = results['img'] + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta})') + return repr_str diff --git a/annotator/uniformer/mmseg/datasets/stare.py b/annotator/uniformer/mmseg/datasets/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd14e0920e7f6a73baff1432e5a32ccfdb0dfae --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/stare.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class STAREDataset(CustomDataset): + """STARE dataset. + + In segmentation map annotation for STARE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.ah.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(STAREDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.ah.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/annotator/uniformer/mmseg/datasets/voc.py b/annotator/uniformer/mmseg/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..a8855203b14ee0dc4da9099a2945d4aedcffbcd6 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/voc.py @@ -0,0 +1,29 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PascalVOCDataset(CustomDataset): + """Pascal VOC dataset. + + Args: + split (str): Split txt file for Pascal VOC. + """ + + CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', + 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', + 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', + 'train', 'tvmonitor') + + PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + def __init__(self, split, **kwargs): + super(PascalVOCDataset, self).__init__( + img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) + assert osp.exists(self.img_dir) and self.split is not None diff --git a/annotator/uniformer/mmseg/models/__init__.py b/annotator/uniformer/mmseg/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf93f8bec9cf0cef0a3bd76ca3ca92eb188f535 --- /dev/null +++ b/annotator/uniformer/mmseg/models/__init__.py @@ -0,0 +1,12 @@ +from .backbones import * # noqa: F401,F403 +from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, + build_head, build_loss, build_segmentor) +from .decode_heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .segmentors import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', + 'build_head', 'build_loss', 'build_segmentor' +] diff --git a/annotator/uniformer/mmseg/models/backbones/__init__.py b/annotator/uniformer/mmseg/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8339983905fb5d20bae42ba6f76fea75d278b1aa --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/__init__.py @@ -0,0 +1,17 @@ +from .cgnet import CGNet +# from .fast_scnn import FastSCNN +from .hrnet import HRNet +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1c, ResNetV1d +from .resnext import ResNeXt +from .unet import UNet +from .vit import VisionTransformer +from .uniformer import UniFormer + +__all__ = [ + 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', + 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', + 'VisionTransformer', 'UniFormer' +] diff --git a/annotator/uniformer/mmseg/models/backbones/cgnet.py b/annotator/uniformer/mmseg/models/backbones/cgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f8bca442c8f18179f217e40c298fb5ef39df77c4 --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/cgnet.py @@ -0,0 +1,367 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from annotator.uniformer.mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer, + constant_init, kaiming_init) +from annotator.uniformer.mmcv.runner import load_checkpoint +from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm + +from annotator.uniformer.mmseg.utils import get_root_logger +from ..builder import BACKBONES + + +class GlobalContextExtractor(nn.Module): + """Global Context Extractor for CGNet. + + This class is employed to refine the joint feature of both local feature + and surrounding context. + + Args: + channel (int): Number of input feature channels. + reduction (int): Reductions for global context extractor. Default: 16. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, channel, reduction=16, with_cp=False): + super(GlobalContextExtractor, self).__init__() + self.channel = channel + self.reduction = reduction + assert reduction >= 1 and channel >= reduction + self.with_cp = with_cp + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + + def _inner_forward(x): + num_batch, num_channel = x.size()[:2] + y = self.avg_pool(x).view(num_batch, num_channel) + y = self.fc(y).view(num_batch, num_channel, 1, 1) + return x * y + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class ContextGuidedBlock(nn.Module): + """Context Guided Block for CGNet. + + This class consists of four components: local feature extractor, + surrounding feature extractor, joint feature extractor and global + context extractor. + + Args: + in_channels (int): Number of input feature channels. + out_channels (int): Number of output feature channels. + dilation (int): Dilation rate for surrounding context extractor. + Default: 2. + reduction (int): Reduction for global context extractor. Default: 16. + skip_connect (bool): Add input to output or not. Default: True. + downsample (bool): Downsample the input to 1/2 or not. Default: False. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels, + out_channels, + dilation=2, + reduction=16, + skip_connect=True, + downsample=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + with_cp=False): + super(ContextGuidedBlock, self).__init__() + self.with_cp = with_cp + self.downsample = downsample + + channels = out_channels if downsample else out_channels // 2 + if 'type' in act_cfg and act_cfg['type'] == 'PReLU': + act_cfg['num_parameters'] = channels + kernel_size = 3 if downsample else 1 + stride = 2 if downsample else 1 + padding = (kernel_size - 1) // 2 + + self.conv1x1 = ConvModule( + in_channels, + channels, + kernel_size, + stride, + padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.f_loc = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=False) + self.f_sur = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=dilation, + groups=channels, + dilation=dilation, + bias=False) + + self.bn = build_norm_layer(norm_cfg, 2 * channels)[1] + self.activate = nn.PReLU(2 * channels) + + if downsample: + self.bottleneck = build_conv_layer( + conv_cfg, + 2 * channels, + out_channels, + kernel_size=1, + bias=False) + + self.skip_connect = skip_connect and not downsample + self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp) + + def forward(self, x): + + def _inner_forward(x): + out = self.conv1x1(x) + loc = self.f_loc(out) + sur = self.f_sur(out) + + joi_feat = torch.cat([loc, sur], 1) # the joint feature + joi_feat = self.bn(joi_feat) + joi_feat = self.activate(joi_feat) + if self.downsample: + joi_feat = self.bottleneck(joi_feat) # channel = out_channels + # f_glo is employed to refine the joint feature + out = self.f_glo(joi_feat) + + if self.skip_connect: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InputInjection(nn.Module): + """Downsampling module for CGNet.""" + + def __init__(self, num_downsampling): + super(InputInjection, self).__init__() + self.pool = nn.ModuleList() + for i in range(num_downsampling): + self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) + + def forward(self, x): + for pool in self.pool: + x = pool(x) + return x + + +@BACKBONES.register_module() +class CGNet(nn.Module): + """CGNet backbone. + + A Light-weight Context Guided Network for Semantic Segmentation + arXiv: https://arxiv.org/abs/1811.08201 + + Args: + in_channels (int): Number of input image channels. Normally 3. + num_channels (tuple[int]): Numbers of feature channels at each stages. + Default: (32, 64, 128). + num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2. + Default: (3, 21). + dilations (tuple[int]): Dilation rate for surrounding context + extractors at stage 1 and stage 2. Default: (2, 4). + reductions (tuple[int]): Reductions for global context extractors at + stage 1 and stage 2. Default: (8, 16). + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + norm_eval=False, + with_cp=False): + + super(CGNet, self).__init__() + self.in_channels = in_channels + self.num_channels = num_channels + assert isinstance(self.num_channels, tuple) and len( + self.num_channels) == 3 + self.num_blocks = num_blocks + assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2 + self.dilations = dilations + assert isinstance(self.dilations, tuple) and len(self.dilations) == 2 + self.reductions = reductions + assert isinstance(self.reductions, tuple) and len(self.reductions) == 2 + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU': + self.act_cfg['num_parameters'] = num_channels[0] + self.norm_eval = norm_eval + self.with_cp = with_cp + + cur_channels = in_channels + self.stem = nn.ModuleList() + for i in range(3): + self.stem.append( + ConvModule( + cur_channels, + num_channels[0], + 3, + 2 if i == 0 else 1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + cur_channels = num_channels[0] + + self.inject_2x = InputInjection(1) # down-sample for Input, factor=2 + self.inject_4x = InputInjection(2) # down-sample for Input, factor=4 + + cur_channels += in_channels + self.norm_prelu_0 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 1 + self.level1 = nn.ModuleList() + for i in range(num_blocks[0]): + self.level1.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[1], + num_channels[1], + dilations[0], + reductions[0], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[1] + in_channels + self.norm_prelu_1 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 2 + self.level2 = nn.ModuleList() + for i in range(num_blocks[1]): + self.level2.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[2], + num_channels[2], + dilations[1], + reductions[1], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[2] + self.norm_prelu_2 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + def forward(self, x): + output = [] + + # stage 0 + inp_2x = self.inject_2x(x) + inp_4x = self.inject_4x(x) + for layer in self.stem: + x = layer(x) + x = self.norm_prelu_0(torch.cat([x, inp_2x], 1)) + output.append(x) + + # stage 1 + for i, layer in enumerate(self.level1): + x = layer(x) + if i == 0: + down1 = x + x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1)) + output.append(x) + + # stage 2 + for i, layer in enumerate(self.level2): + x = layer(x) + if i == 0: + down2 = x + x = self.norm_prelu_2(torch.cat([down2, x], 1)) + output.append(x) + + return output + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + elif isinstance(m, nn.PReLU): + constant_init(m, 0) + else: + raise TypeError('pretrained must be a str or None') + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super(CGNet, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/annotator/uniformer/mmseg/models/backbones/fast_scnn.py b/annotator/uniformer/mmseg/models/backbones/fast_scnn.py new file mode 100644 index 0000000000000000000000000000000000000000..38c2350177cbc2066f45add568d30eb6041f74f3 --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/fast_scnn.py @@ -0,0 +1,375 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init, + kaiming_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from annotator.uniformer.mmseg.models.decode_heads.psp_head import PPM +from annotator.uniformer.mmseg.ops import resize +from ..builder import BACKBONES +from ..utils.inverted_residual import InvertedResidual + + +class LearningToDownsample(nn.Module): + """Learning to downsample module. + + Args: + in_channels (int): Number of input channels. + dw_channels (tuple[int]): Number of output channels of the first and + the second depthwise conv (dwconv) layers. + out_channels (int): Number of output channels of the whole + 'learning to downsample' module. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + """ + + def __init__(self, + in_channels, + dw_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU')): + super(LearningToDownsample, self).__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + dw_channels1 = dw_channels[0] + dw_channels2 = dw_channels[1] + + self.conv = ConvModule( + in_channels, + dw_channels1, + 3, + stride=2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.dsconv1 = DepthwiseSeparableConvModule( + dw_channels1, + dw_channels2, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg) + self.dsconv2 = DepthwiseSeparableConvModule( + dw_channels2, + out_channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg) + + def forward(self, x): + x = self.conv(x) + x = self.dsconv1(x) + x = self.dsconv2(x) + return x + + +class GlobalFeatureExtractor(nn.Module): + """Global feature extractor module. + + Args: + in_channels (int): Number of input channels of the GFE module. + Default: 64 + block_channels (tuple[int]): Tuple of ints. Each int specifies the + number of output channels of each Inverted Residual module. + Default: (64, 96, 128) + out_channels(int): Number of output channels of the GFE module. + Default: 128 + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + Default: 6 + num_blocks (tuple[int]): Tuple of ints. Each int specifies the + number of times each Inverted Residual module is repeated. + The repeated Inverted Residual modules are called a 'group'. + Default: (3, 3, 3) + strides (tuple[int]): Tuple of ints. Each int specifies + the downsampling factor of each 'group'. + Default: (2, 2, 1) + pool_scales (tuple[int]): Tuple of ints. Each int specifies + the parameter required in 'global average pooling' within PPM. + Default: (1, 2, 3, 6) + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + """ + + def __init__(self, + in_channels=64, + block_channels=(64, 96, 128), + out_channels=128, + expand_ratio=6, + num_blocks=(3, 3, 3), + strides=(2, 2, 1), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False): + super(GlobalFeatureExtractor, self).__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + assert len(block_channels) == len(num_blocks) == 3 + self.bottleneck1 = self._make_layer(in_channels, block_channels[0], + num_blocks[0], strides[0], + expand_ratio) + self.bottleneck2 = self._make_layer(block_channels[0], + block_channels[1], num_blocks[1], + strides[1], expand_ratio) + self.bottleneck3 = self._make_layer(block_channels[1], + block_channels[2], num_blocks[2], + strides[2], expand_ratio) + self.ppm = PPM( + pool_scales, + block_channels[2], + block_channels[2] // 4, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=align_corners) + self.out = ConvModule( + block_channels[2] * 2, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _make_layer(self, + in_channels, + out_channels, + blocks, + stride=1, + expand_ratio=6): + layers = [ + InvertedResidual( + in_channels, + out_channels, + stride, + expand_ratio, + norm_cfg=self.norm_cfg) + ] + for i in range(1, blocks): + layers.append( + InvertedResidual( + out_channels, + out_channels, + 1, + expand_ratio, + norm_cfg=self.norm_cfg)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.bottleneck1(x) + x = self.bottleneck2(x) + x = self.bottleneck3(x) + x = torch.cat([x, *self.ppm(x)], dim=1) + x = self.out(x) + return x + + +class FeatureFusionModule(nn.Module): + """Feature fusion module. + + Args: + higher_in_channels (int): Number of input channels of the + higher-resolution branch. + lower_in_channels (int): Number of input channels of the + lower-resolution branch. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + """ + + def __init__(self, + higher_in_channels, + lower_in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False): + super(FeatureFusionModule, self).__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.dwconv = ConvModule( + lower_in_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.conv_lower_res = ConvModule( + out_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.conv_higher_res = ConvModule( + higher_in_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.relu = nn.ReLU(True) + + def forward(self, higher_res_feature, lower_res_feature): + lower_res_feature = resize( + lower_res_feature, + size=higher_res_feature.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + lower_res_feature = self.dwconv(lower_res_feature) + lower_res_feature = self.conv_lower_res(lower_res_feature) + + higher_res_feature = self.conv_higher_res(higher_res_feature) + out = higher_res_feature + lower_res_feature + return self.relu(out) + + +@BACKBONES.register_module() +class FastSCNN(nn.Module): + """Fast-SCNN Backbone. + + Args: + in_channels (int): Number of input image channels. Default: 3. + downsample_dw_channels (tuple[int]): Number of output channels after + the first conv layer & the second conv layer in + Learning-To-Downsample (LTD) module. + Default: (32, 48). + global_in_channels (int): Number of input channels of + Global Feature Extractor(GFE). + Equal to number of output channels of LTD. + Default: 64. + global_block_channels (tuple[int]): Tuple of integers that describe + the output channels for each of the MobileNet-v2 bottleneck + residual blocks in GFE. + Default: (64, 96, 128). + global_block_strides (tuple[int]): Tuple of integers + that describe the strides (downsampling factors) for each of the + MobileNet-v2 bottleneck residual blocks in GFE. + Default: (2, 2, 1). + global_out_channels (int): Number of output channels of GFE. + Default: 128. + higher_in_channels (int): Number of input channels of the higher + resolution branch in FFM. + Equal to global_in_channels. + Default: 64. + lower_in_channels (int): Number of input channels of the lower + resolution branch in FFM. + Equal to global_out_channels. + Default: 128. + fusion_out_channels (int): Number of output channels of FFM. + Default: 128. + out_indices (tuple): Tuple of indices of list + [higher_res_features, lower_res_features, fusion_output]. + Often set to (0,1,2) to enable aux. heads. + Default: (0, 1, 2). + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + """ + + def __init__(self, + in_channels=3, + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False): + + super(FastSCNN, self).__init__() + if global_in_channels != higher_in_channels: + raise AssertionError('Global Input Channels must be the same \ + with Higher Input Channels!') + elif global_out_channels != lower_in_channels: + raise AssertionError('Global Output Channels must be the same \ + with Lower Input Channels!') + + self.in_channels = in_channels + self.downsample_dw_channels1 = downsample_dw_channels[0] + self.downsample_dw_channels2 = downsample_dw_channels[1] + self.global_in_channels = global_in_channels + self.global_block_channels = global_block_channels + self.global_block_strides = global_block_strides + self.global_out_channels = global_out_channels + self.higher_in_channels = higher_in_channels + self.lower_in_channels = lower_in_channels + self.fusion_out_channels = fusion_out_channels + self.out_indices = out_indices + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.learning_to_downsample = LearningToDownsample( + in_channels, + downsample_dw_channels, + global_in_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.global_feature_extractor = GlobalFeatureExtractor( + global_in_channels, + global_block_channels, + global_out_channels, + strides=self.global_block_strides, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.feature_fusion = FeatureFusionModule( + higher_in_channels, + lower_in_channels, + fusion_out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def init_weights(self, pretrained=None): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + + def forward(self, x): + higher_res_features = self.learning_to_downsample(x) + lower_res_features = self.global_feature_extractor(higher_res_features) + fusion_output = self.feature_fusion(higher_res_features, + lower_res_features) + + outs = [higher_res_features, lower_res_features, fusion_output] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/annotator/uniformer/mmseg/models/backbones/hrnet.py b/annotator/uniformer/mmseg/models/backbones/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..331ebf3ccb8597b3f507670753789073fc3c946d --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/hrnet.py @@ -0,0 +1,555 @@ +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, + kaiming_init) +from annotator.uniformer.mmcv.runner import load_checkpoint +from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm + +from annotator.uniformer.mmseg.ops import Upsample, resize +from annotator.uniformer.mmseg.utils import get_root_logger +from ..builder import BACKBONES +from .resnet import BasicBlock, Bottleneck + + +class HRModule(nn.Module): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. + """ + + def __init__(self, + num_branches, + blocks, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True)): + super(HRModule, self).__init__() + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + """Check branches configuration.""" + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \ + f'{len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \ + f'{len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \ + f'{len(in_channels)})' + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + """Build one branch.""" + downsample = None + if stride != 1 or \ + self.in_channels[branch_index] != \ + num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.in_channels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, num_channels[branch_index] * + block.expansion)[1]) + + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + self.in_channels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + """Build multiple branch.""" + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + """Build fuse layer.""" + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + # we set align_corners=False for HRNet + Upsample( + scale_factor=2**(j - i), + mode='bilinear', + align_corners=False))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + elif j > i: + y = y + resize( + self.fuse_layers[i][j](x[j]), + size=x[i].shape[2:], + mode='bilinear', + align_corners=False) + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@BACKBONES.register_module() +class HRNet(nn.Module): + """HRNet backbone. + + High-Resolution Representations for Labeling Pixels and Regions + arXiv: https://arxiv.org/abs/1904.04514 + + Args: + extra (dict): detailed configuration for each stage of HRNet. + in_channels (int): Number of input image channels. Normally 3. + conv_cfg (dict): dictionary to construct and config conv layer. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from annotator.uniformer.mmseg.models import HRNet + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=False): + super(HRNet, self).__init__() + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.zero_init_residual = zero_init_residual + + # stem net + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + 64, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # stage 1 + self.stage1_cfg = self.extra['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'][0] + + block = self.blocks_dict[block_type] + stage1_out_channels = num_channels * block.expansion + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + + # stage 2 + self.stage2_cfg = self.extra['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = self.stage2_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition1 = self._make_transition_layer([stage1_out_channels], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # stage 3 + self.stage3_cfg = self.extra['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = self.stage3_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + # stage 4 + self.stage4_cfg = self.extra['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = self.stage4_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + """Make transition layer.""" + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + """Make each layer.""" + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + + layers = [] + layers.append( + block( + inplanes, + planes, + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + inplanes, + planes, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + """Make each stage.""" + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + + return nn.Sequential(*hr_modules), in_channels + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['num_branches']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['num_branches']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['num_branches']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + return y_list + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super(HRNet, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py b/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6b3791692a0d1b5da3601875711710b7bd01ba --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py @@ -0,0 +1,180 @@ +import logging + +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init +from annotator.uniformer.mmcv.runner import load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from ..utils import InvertedResidual, make_divisible + + +@BACKBONES.register_module() +class MobileNetV2(nn.Module): + """MobileNetV2 backbone. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + strides (Sequence[int], optional): Strides of the first block of each + layer. If not specified, default config in ``arch_setting`` will + be used. + dilations (Sequence[int]): Dilation of each layer. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks. + arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], + [6, 96, 3], [6, 160, 3], [6, 320, 1]] + + def __init__(self, + widen_factor=1., + strides=(1, 2, 2, 2, 1, 2, 1), + dilations=(1, 1, 1, 1, 1, 1, 1), + out_indices=(1, 2, 4, 6), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False): + super(MobileNetV2, self).__init__() + self.widen_factor = widen_factor + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == len(self.arch_settings) + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 7): + raise ValueError('the item in out_indices must in ' + f'range(0, 8). But received {index}') + + if frozen_stages not in range(-1, 7): + raise ValueError('frozen_stages must be in range(-1, 7). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks = layer_cfg + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + def make_layer(self, out_channels, num_blocks, stride, dilation, + expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. + dilation (int): Dilation of the first block. + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. + """ + layers = [] + for i in range(num_blocks): + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride if i == 0 else 1, + expand_ratio=expand_ratio, + dilation=dilation if i == 0 else 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py b/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..16817400b4102899794fe64c9644713a4e54e2f9 --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py @@ -0,0 +1,255 @@ +import logging + +import annotator.uniformer.mmcv as mmcv +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init +from annotator.uniformer.mmcv.cnn.bricks import Conv2dAdaptivePadding +from annotator.uniformer.mmcv.runner import load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from ..utils import InvertedResidualV3 as InvertedResidual + + +@BACKBONES.register_module() +class MobileNetV3(nn.Module): + """MobileNetV3 backbone. + + This backbone is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + arch (str): Architecture of mobilnetv3, from {'small', 'large'}. + Default: 'small'. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (tuple[int]): Output from which layer. + Default: (0, 1, 12). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 + [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 + [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN'), + out_indices=(0, 1, 12), + frozen_stages=-1, + reduction_factor=1, + norm_eval=False, + with_cp=False): + super(MobileNetV3, self).__init__() + assert arch in self.arch_settings + assert isinstance(reduction_factor, int) and reduction_factor > 0 + assert mmcv.is_tuple_of(out_indices, int) + for index in out_indices: + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch])+2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch])+2}). ' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.reduction_factor = reduction_factor + self.norm_eval = norm_eval + self.with_cp = with_cp + self.layers = self._make_layer() + + def _make_layer(self): + layers = [] + + # build the first layer (layer0) + in_channels = 16 + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + layer_setting = self.arch_settings[self.arch] + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + + if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ + i >= 8: + mid_channels = mid_channels // self.reduction_factor + out_channels = out_channels // self.reduction_factor + + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + with_expand_conv=(in_channels != mid_channels), + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = 'layer{}'.format(i + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # build the last layer + # block5 layer12 os=32 for small model + # block6 layer16 os=32 for large model + layer = ConvModule( + in_channels=in_channels, + out_channels=576 if self.arch == 'small' else 960, + kernel_size=1, + stride=1, + dilation=4, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = 'layer{}'.format(len(layer_setting) + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # next, convert backbone MobileNetV3 to a semantic segmentation version + if self.arch == 'small': + self.layer4.depthwise_conv.conv.stride = (1, 1) + self.layer9.depthwise_conv.conv.stride = (1, 1) + for i in range(4, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 9: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + else: + self.layer7.depthwise_conv.conv.stride = (1, 1) + self.layer13.depthwise_conv.conv.stride = (1, 1) + for i in range(7, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 13: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + + return layers + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + return outs + + def _freeze_stages(self): + for i in range(self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV3, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/annotator/uniformer/mmseg/models/backbones/resnest.py b/annotator/uniformer/mmseg/models/backbones/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..b45a837f395230029e9d4194ff9f7f2f8f7067b0 --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/resnest.py @@ -0,0 +1,314 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d in ResNeSt. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels. Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + dcn (dict): Config dict for DCN. Default: None. + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None): + super(SplitAttentionConv2d, self).__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.with_dcn = dcn is not None + self.dcn = dcn + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_dcn and not fallback_on_stride: + assert conv_cfg is None, 'conv_cfg must be None for DCN' + conv_cfg = dcn + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + """nn.Module: the normalization layer named "norm0" """ + return getattr(self, self.norm0_name) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + batch = x.size(0) + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + inplane (int): Input planes of this block. + planes (int): Middle planes of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Key word arguments for base class. + """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + """Bottleneck block for ResNeSt.""" + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.with_modulated_dcn = False + self.conv2 = SplitAttentionConv2d( + width, + width, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=self.dcn) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + Args: + groups (int): Number of groups of Bottleneck. Default: 1 + base_width (int): Base width of Bottleneck. Default: 4 + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Keyword arguments for ResNet. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)) + } + + def __init__(self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.base_width = base_width + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super(ResNeSt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/annotator/uniformer/mmseg/models/backbones/resnet.py b/annotator/uniformer/mmseg/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4e52bf048d28ecb069db4728e5f05ad85ac53198 --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/resnet.py @@ -0,0 +1,688 @@ +import torch.nn as nn +import torch.utils.checkpoint as cp +from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, + constant_init, kaiming_init) +from annotator.uniformer.mmcv.runner import load_checkpoint +from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm + +from annotator.uniformer.mmseg.utils import get_root_logger +from ..builder import BACKBONES +from ..utils import ResLayer + + +class BasicBlock(nn.Module): + """Basic block for ResNet.""" + + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None): + super(BasicBlock, self).__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. + """ + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None): + super(Bottleneck, self).__init__() + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. + """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + """Forward function for plugins.""" + out = x + for name in plugin_names: + out = getattr(self, name)(x) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class ResNet(nn.Module): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default" 3. + stem_channels (int): Number of stem channels. Default: 64. + base_channels (int): Number of base channels of res layer. Default: 64. + num_stages (int): Resnet stages, normally 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + + - position (str, required): Position inside block to insert plugin, + options: 'after_conv1', 'after_conv2', 'after_conv3'. + + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages' + multi_grid (Sequence[int]|None): Multi grid dilation rates of last + stage. Default: None + contract_dilation (bool): Whether contract first dilation of each layer + Default: False + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from annotator.uniformer.mmseg.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + multi_grid=None, + contract_dilation=False, + with_cp=False, + zero_init_residual=True): + super(ResNet, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.multi_grid = multi_grid + self.contract_dilation = contract_dilation + self.zero_init_residual = zero_init_residual + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + # multi grid is applied to last layer only + stage_multi_grid = multi_grid if i == len( + self.stage_blocks) - 1 else None + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + multi_grid=stage_multi_grid, + contract_dilation=contract_dilation) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i+1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """make plugins for ResNet 'stage_idx'th stage . + + Currently we support to insert 'context_block', + 'empirical_attention_block', 'nonlocal_block' into the backbone like + ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be : + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose 'stage_idx=0', the structure of blocks in the stage would be: + conv1-> conv2->conv3->yyy->zzz1->zzz2 + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. + stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + """Make stem layer for ResNet.""" + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + + if self.dcn is not None: + for m in self.modules(): + if isinstance(m, Bottleneck) and hasattr( + m, 'conv2_offset'): + constant_init(m.conv2_offset, 0) + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@BACKBONES.register_module() +class ResNetV1c(ResNet): + """ResNetV1c variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv + in the input stem with three 3x3 convs. + + References: + .. [1] https://arxiv.org/pdf/1812.01187.pdf + """ + + def __init__(self, **kwargs): + super(ResNetV1c, self).__init__( + deep_stem=True, avg_down=False, **kwargs) + + +@BACKBONES.register_module() +class ResNetV1d(ResNet): + """ResNetV1d variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/annotator/uniformer/mmseg/models/backbones/resnext.py b/annotator/uniformer/mmseg/models/backbones/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..962249ad6fd9b50960ad6426f7ce3cac6ed8c5bc --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/resnext.py @@ -0,0 +1,145 @@ +import math + +from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. + """ + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@BACKBONES.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Normally 3. + num_stages (int): Resnet stages, normally 4. + groups (int): Group of resnext. + base_width (int): Base width of resnext. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from annotator.uniformer.mmseg.models import ResNeXt + >>> import torch + >>> self = ResNeXt(depth=50) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, groups=1, base_width=4, **kwargs): + self.groups = groups + self.base_width = base_width + super(ResNeXt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/annotator/uniformer/mmseg/models/backbones/unet.py b/annotator/uniformer/mmseg/models/backbones/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..82caa16a94c195c192a2a920fb7bc7e60f0f3ce3 --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/unet.py @@ -0,0 +1,429 @@ +import torch.nn as nn +import torch.utils.checkpoint as cp +from annotator.uniformer.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, + build_norm_layer, constant_init, kaiming_init) +from annotator.uniformer.mmcv.runner import load_checkpoint +from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm + +from annotator.uniformer.mmseg.utils import get_root_logger +from ..builder import BACKBONES +from ..utils import UpConvBlock + + +class BasicConvBlock(nn.Module): + """Basic convolutional block for UNet. + + This module consists of several plain convolutional layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers. Default: 2. + stride (int): Whether use stride convolution to downsample + the input feature map. If stride=2, it only uses stride convolution + in the first convolutional layer to downsample the input feature + map. Options are 1 or 2. Default: 1. + dilation (int): Whether use dilated convolution to expand the + receptive field. Set dilation rate of each convolutional layer and + the dilation rate of the first convolutional layer is always 1. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dcn=None, + plugins=None): + super(BasicConvBlock, self).__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.with_cp = with_cp + convs = [] + for i in range(num_convs): + convs.append( + ConvModule( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride if i == 0 else 1, + dilation=1 if i == 0 else dilation, + padding=1 if i == 0 else dilation, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.convs = nn.Sequential(*convs) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.convs, x) + else: + out = self.convs(x) + return out + + +@UPSAMPLE_LAYERS.register_module() +class DeconvModule(nn.Module): + """Deconvolution upsample module in decoder for UNet (2X upsample). + + This module uses deconvolution to upsample feature map in the decoder + of UNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of the convolutional layer. Default: 4. + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + kernel_size=4, + scale_factor=2): + super(DeconvModule, self).__init__() + + assert (kernel_size - scale_factor >= 0) and\ + (kernel_size - scale_factor) % 2 == 0,\ + f'kernel_size should be greater than or equal to scale_factor '\ + f'and (kernel_size - scale_factor) should be even numbers, '\ + f'while the kernel size is {kernel_size} and scale_factor is '\ + f'{scale_factor}.' + + stride = scale_factor + padding = (kernel_size - scale_factor) // 2 + self.with_cp = with_cp + deconv = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + norm_name, norm = build_norm_layer(norm_cfg, out_channels) + activate = build_activation_layer(act_cfg) + self.deconv_upsamping = nn.Sequential(deconv, norm, activate) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.deconv_upsamping, x) + else: + out = self.deconv_upsamping(x) + return out + + +@UPSAMPLE_LAYERS.register_module() +class InterpConv(nn.Module): + """Interpolation upsample module in decoder for UNet. + + This module uses interpolation to upsample feature map in the decoder + of UNet. It consists of one interpolation upsample layer and one + convolutional layer. It can be one interpolation upsample layer followed + by one convolutional layer (conv_first=False) or one convolutional layer + followed by one interpolation upsample layer (conv_first=True). + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + conv_first (bool): Whether convolutional layer or interpolation + upsample layer first. Default: False. It means interpolation + upsample layer followed by one convolutional layer. + kernel_size (int): Kernel size of the convolutional layer. Default: 1. + stride (int): Stride of the convolutional layer. Default: 1. + padding (int): Padding of the convolutional layer. Default: 1. + upsample_cfg (dict): Interpolation config of the upsample layer. + Default: dict( + scale_factor=2, mode='bilinear', align_corners=False). + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + conv_cfg=None, + conv_first=False, + kernel_size=1, + stride=1, + padding=0, + upsample_cfg=dict( + scale_factor=2, mode='bilinear', align_corners=False)): + super(InterpConv, self).__init__() + + self.with_cp = with_cp + conv = ConvModule( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + upsample = nn.Upsample(**upsample_cfg) + if conv_first: + self.interp_upsample = nn.Sequential(conv, upsample) + else: + self.interp_upsample = nn.Sequential(upsample, conv) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.interp_upsample, x) + else: + out = self.interp_upsample(x) + return out + + +@BACKBONES.register_module() +class UNet(nn.Module): + """UNet backbone. + U-Net: Convolutional Networks for Biomedical Image Segmentation. + https://arxiv.org/pdf/1505.04597.pdf + + Args: + in_channels (int): Number of input image channels. Default" 3. + base_channels (int): Number of base channels of each stage. + The output channels of the first stage. Default: 64. + num_stages (int): Number of stages in encoder, normally 5. Default: 5. + strides (Sequence[int 1 | 2]): Strides of each stage in encoder. + len(strides) is equal to num_stages. Normally the stride of the + first stage in encoder is 1. If strides[i]=2, it uses stride + convolution to downsample in the correspondence encoder stage. + Default: (1, 1, 1, 1, 1). + enc_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence encoder stage. + Default: (2, 2, 2, 2, 2). + dec_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence decoder stage. + Default: (2, 2, 2, 2). + downsamples (Sequence[int]): Whether use MaxPool to downsample the + feature map after the first stage of encoder + (stages: [1, num_stages)). If the correspondence encoder stage use + stride convolution (strides[i]=2), it will never use MaxPool to + downsample, even downsamples[i-1]=True. + Default: (True, True, True, True). + enc_dilations (Sequence[int]): Dilation rate of each stage in encoder. + Default: (1, 1, 1, 1, 1). + dec_dilations (Sequence[int]): Dilation rate of each stage in decoder. + Default: (1, 1, 1, 1). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + + Notice: + The input image size should be divisible by the whole downsample rate + of the encoder. More detail of the whole downsample rate can be found + in UNet._check_input_divisible. + + """ + + def __init__(self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False, + dcn=None, + plugins=None): + super(UNet, self).__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + assert len(strides) == num_stages, \ + 'The length of strides should be equal to num_stages, '\ + f'while the strides is {strides}, the length of '\ + f'strides is {len(strides)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_num_convs) == num_stages, \ + 'The length of enc_num_convs should be equal to num_stages, '\ + f'while the enc_num_convs is {enc_num_convs}, the length of '\ + f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_num_convs) == (num_stages-1), \ + 'The length of dec_num_convs should be equal to (num_stages-1), '\ + f'while the dec_num_convs is {dec_num_convs}, the length of '\ + f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(downsamples) == (num_stages-1), \ + 'The length of downsamples should be equal to (num_stages-1), '\ + f'while the downsamples is {downsamples}, the length of '\ + f'downsamples is {len(downsamples)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_dilations) == num_stages, \ + 'The length of enc_dilations should be equal to num_stages, '\ + f'while the enc_dilations is {enc_dilations}, the length of '\ + f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_dilations) == (num_stages-1), \ + 'The length of dec_dilations should be equal to (num_stages-1), '\ + f'while the dec_dilations is {dec_dilations}, the length of '\ + f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\ + f'{num_stages}.' + self.num_stages = num_stages + self.strides = strides + self.downsamples = downsamples + self.norm_eval = norm_eval + self.base_channels = base_channels + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for i in range(num_stages): + enc_conv_block = [] + if i != 0: + if strides[i] == 1 and downsamples[i - 1]: + enc_conv_block.append(nn.MaxPool2d(kernel_size=2)) + upsample = (strides[i] != 1 or downsamples[i - 1]) + self.decoder.append( + UpConvBlock( + conv_block=BasicConvBlock, + in_channels=base_channels * 2**i, + skip_channels=base_channels * 2**(i - 1), + out_channels=base_channels * 2**(i - 1), + num_convs=dec_num_convs[i - 1], + stride=1, + dilation=dec_dilations[i - 1], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + upsample_cfg=upsample_cfg if upsample else None, + dcn=None, + plugins=None)) + + enc_conv_block.append( + BasicConvBlock( + in_channels=in_channels, + out_channels=base_channels * 2**i, + num_convs=enc_num_convs[i], + stride=strides[i], + dilation=enc_dilations[i], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None)) + self.encoder.append((nn.Sequential(*enc_conv_block))) + in_channels = base_channels * 2**i + + def forward(self, x): + self._check_input_divisible(x) + enc_outs = [] + for enc in self.encoder: + x = enc(x) + enc_outs.append(x) + dec_outs = [x] + for i in reversed(range(len(self.decoder))): + x = self.decoder[i](enc_outs[i], x) + dec_outs.append(x) + + return dec_outs + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(UNet, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _check_input_divisible(self, x): + h, w = x.shape[-2:] + whole_downsample_rate = 1 + for i in range(1, self.num_stages): + if self.strides[i] == 2 or self.downsamples[i - 1]: + whole_downsample_rate *= 2 + assert (h % whole_downsample_rate == 0) \ + and (w % whole_downsample_rate == 0),\ + f'The input image size {(h, w)} should be divisible by the whole '\ + f'downsample rate {whole_downsample_rate}, when num_stages is '\ + f'{self.num_stages}, strides is {self.strides}, and downsamples '\ + f'is {self.downsamples}.' + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') diff --git a/annotator/uniformer/mmseg/models/backbones/uniformer.py b/annotator/uniformer/mmseg/models/backbones/uniformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c4bb88e4c928540cca9ab609988b916520f5b7a --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/uniformer.py @@ -0,0 +1,422 @@ +# -------------------------------------------------------- +# UniFormer +# Copyright (c) 2022 SenseTime X-Lab +# Licensed under The MIT License [see LICENSE for details] +# Written by Kunchang Li +# -------------------------------------------------------- + +from collections import OrderedDict +import math + +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from annotator.uniformer.mmcv_custom import load_checkpoint +from annotator.uniformer.mmseg.utils import get_root_logger +from ..builder import BACKBONES + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CBlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + self.norm1 = nn.BatchNorm2d(dim) + self.conv1 = nn.Conv2d(dim, dim, 1) + self.conv2 = nn.Conv2d(dim, dim, 1) + self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = nn.BatchNorm2d(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x))))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SABlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + B, N, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.transpose(1, 2).reshape(B, N, H, W) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class SABlock_Windows(nn.Module): + def __init__(self, dim, num_heads, window_size=14, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.window_size=window_size + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + x = x.permute(0, 2, 3, 1) + B, H, W, C = x.shape + shortcut = x + x = self.norm1(x) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.permute(0, 3, 1, 2).reshape(B, C, H, W) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.norm = nn.LayerNorm(embed_dim) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, _, H, W = x.shape + x = self.proj(x) + B, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + return x + + +@BACKBONES.register_module() +class UniFormer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, layers=[3, 4, 8, 3], img_size=224, in_chans=3, num_classes=80, embed_dim=[64, 128, 320, 512], + head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + pretrained_path=None, use_checkpoint=False, checkpoint_num=[0, 0, 0, 0], + windows=False, hybrid=False, window_size=14): + """ + Args: + layer (list): number of block in each layer + img_size (int, tuple): input image size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + head_dim (int): dimension of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer (nn.Module): normalization layer + pretrained_path (str): path of pretrained model + use_checkpoint (bool): whether use checkpoint + checkpoint_num (list): index for using checkpoint in every stage + windows (bool): whether use window MHRA + hybrid (bool): whether use hybrid MHRA + window_size (int): size of window (>14) + """ + super().__init__() + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.checkpoint_num = checkpoint_num + self.windows = windows + print(f'Use Checkpoint: {self.use_checkpoint}') + print(f'Checkpoint Number: {self.checkpoint_num}') + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0]) + self.patch_embed2 = PatchEmbed( + img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1]) + self.patch_embed3 = PatchEmbed( + img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2]) + self.patch_embed4 = PatchEmbed( + img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3]) + + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))] # stochastic depth decay rule + num_heads = [dim // head_dim for dim in embed_dim] + self.blocks1 = nn.ModuleList([ + CBlock( + dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(layers[0])]) + self.norm1=norm_layer(embed_dim[0]) + self.blocks2 = nn.ModuleList([ + CBlock( + dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]], norm_layer=norm_layer) + for i in range(layers[1])]) + self.norm2 = norm_layer(embed_dim[1]) + if self.windows: + print('Use local window for all blocks in stage3') + self.blocks3 = nn.ModuleList([ + SABlock_Windows( + dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer) + for i in range(layers[2])]) + elif hybrid: + print('Use hybrid window for blocks in stage3') + block3 = [] + for i in range(layers[2]): + if (i + 1) % 4 == 0: + block3.append(SABlock( + dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)) + else: + block3.append(SABlock_Windows( + dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)) + self.blocks3 = nn.ModuleList(block3) + else: + print('Use global window for all blocks in stage3') + self.blocks3 = nn.ModuleList([ + SABlock( + dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer) + for i in range(layers[2])]) + self.norm3 = norm_layer(embed_dim[2]) + self.blocks4 = nn.ModuleList([ + SABlock( + dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]+layers[2]], norm_layer=norm_layer) + for i in range(layers[3])]) + self.norm4 = norm_layer(embed_dim[3]) + + # Representation layer + if representation_size: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + self.apply(self._init_weights) + self.init_weights(pretrained=pretrained_path) + + def init_weights(self, pretrained): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + print(f'Load pretrained model from {pretrained}') + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + out = [] + x = self.patch_embed1(x) + x = self.pos_drop(x) + for i, blk in enumerate(self.blocks1): + if self.use_checkpoint and i < self.checkpoint_num[0]: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_out = self.norm1(x.permute(0, 2, 3, 1)) + out.append(x_out.permute(0, 3, 1, 2).contiguous()) + x = self.patch_embed2(x) + for i, blk in enumerate(self.blocks2): + if self.use_checkpoint and i < self.checkpoint_num[1]: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_out = self.norm2(x.permute(0, 2, 3, 1)) + out.append(x_out.permute(0, 3, 1, 2).contiguous()) + x = self.patch_embed3(x) + for i, blk in enumerate(self.blocks3): + if self.use_checkpoint and i < self.checkpoint_num[2]: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_out = self.norm3(x.permute(0, 2, 3, 1)) + out.append(x_out.permute(0, 3, 1, 2).contiguous()) + x = self.patch_embed4(x) + for i, blk in enumerate(self.blocks4): + if self.use_checkpoint and i < self.checkpoint_num[3]: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_out = self.norm4(x.permute(0, 2, 3, 1)) + out.append(x_out.permute(0, 3, 1, 2).contiguous()) + return tuple(out) + + def forward(self, x): + x = self.forward_features(x) + return x diff --git a/annotator/uniformer/mmseg/models/backbones/vit.py b/annotator/uniformer/mmseg/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..59e4479650690e08cbc4cab9427aefda47c2116d --- /dev/null +++ b/annotator/uniformer/mmseg/models/backbones/vit.py @@ -0,0 +1,459 @@ +"""Modified from https://github.com/rwightman/pytorch-image- +models/blob/master/timm/models/vision_transformer.py.""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from annotator.uniformer.mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer, + constant_init, kaiming_init, normal_init) +from annotator.uniformer.mmcv.runner import _load_checkpoint +from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm + +from annotator.uniformer.mmseg.utils import get_root_logger +from ..builder import BACKBONES +from ..utils import DropPath, trunc_normal_ + + +class Mlp(nn.Module): + """MLP layer for Encoder block. + + Args: + in_features(int): Input dimension for the first fully + connected layer. + hidden_features(int): Output dimension for the first fully + connected layer. + out_features(int): Output dementsion for the second fully + connected layer. + act_cfg(dict): Config dict for activation layer. + Default: dict(type='GELU'). + drop(float): Drop rate for the dropout layer. Dropout rate has + to be between 0 and 1. Default: 0. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super(Mlp, self).__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = Linear(in_features, hidden_features) + self.act = build_activation_layer(act_cfg) + self.fc2 = Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + """Attention layer for Encoder block. + + Args: + dim (int): Dimension for the input vector. + num_heads (int): Number of parallel attention heads. + qkv_bias (bool): Enable bias for qkv if True. Default: False. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + attn_drop (float): Drop rate for attention output weights. + Default: 0. + proj_drop (float): Drop rate for output weights. Default: 0. + """ + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super(Attention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, + c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + """Implements encoder block with residual connection. + + Args: + dim (int): The feature dimension. + num_heads (int): Number of parallel attention heads. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop (float): Drop rate for mlp output weights. Default: 0. + attn_drop (float): Drop rate for attention output weights. + Default: 0. + proj_drop (float): Drop rate for attn layer output weights. + Default: 0. + drop_path (float): Drop rate for paths of model. + Default: 0. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN', requires_grad=True). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + dim, + num_heads, + mlp_ratio=4, + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + proj_drop=0., + drop_path=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN', eps=1e-6), + with_cp=False): + super(Block, self).__init__() + self.with_cp = with_cp + _, self.norm1 = build_norm_layer(norm_cfg, dim) + self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop, + proj_drop) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + _, self.norm2 = build_norm_layer(norm_cfg, dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + def forward(self, x): + + def _inner_forward(x): + out = x + self.drop_path(self.attn(self.norm1(x))) + out = out + self.drop_path(self.mlp(self.norm2(out))) + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding. + + Args: + img_size (int | tuple): Input image size. + default: 224. + patch_size (int): Width and height for a patch. + default: 16. + in_channels (int): Input channels for images. Default: 3. + embed_dim (int): The embedding dimension. Default: 768. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dim=768): + super(PatchEmbed, self).__init__() + if isinstance(img_size, int): + self.img_size = (img_size, img_size) + elif isinstance(img_size, tuple): + self.img_size = img_size + else: + raise TypeError('img_size must be type of int or tuple') + h, w = self.img_size + self.patch_size = (patch_size, patch_size) + self.num_patches = (h // patch_size) * (w // patch_size) + self.proj = Conv2d( + in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + return self.proj(x).flatten(2).transpose(1, 2) + + +@BACKBONES.register_module() +class VisionTransformer(nn.Module): + """Vision transformer backbone. + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for + Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 + + Args: + img_size (tuple): input image size. Default: (224, 224). + patch_size (int, tuple): patch size. Default: 16. + in_channels (int): number of input channels. Default: 3. + embed_dim (int): embedding dimension. Default: 768. + depth (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qkv_bias (bool): enable bias for qkv if True. Default: True. + qk_scale (float): override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): dropout rate. Default: 0. + attn_drop_rate (float): attention dropout rate. Default: 0. + drop_path_rate (float): Rate of DropPath. Default: 0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN', eps=1e-6, requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='GELU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Default: bicubic. + with_cls_token (bool): If concatenating class token into image tokens + as transformer input. Default: True. + with_cp (bool): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + """ + + def __init__(self, + img_size=(224, 224), + patch_size=16, + in_channels=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + out_indices=11, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True), + act_cfg=dict(type='GELU'), + norm_eval=False, + final_norm=False, + with_cls_token=True, + interpolate_mode='bicubic', + with_cp=False): + super(VisionTransformer, self).__init__() + self.img_size = img_size + self.patch_size = patch_size + self.features = self.embed_dim = embed_dim + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim) + + self.with_cls_token = with_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=dpr[i], + attn_drop=attn_drop_rate, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp) for i in range(depth) + ]) + + self.interpolate_mode = interpolate_mode + self.final_norm = final_norm + if final_norm: + _, self.norm = build_norm_layer(norm_cfg, embed_dim) + + self.norm_eval = norm_eval + self.with_cp = with_cp + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + checkpoint = _load_checkpoint(pretrained, logger=logger) + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + if 'pos_embed' in state_dict.keys(): + if self.pos_embed.shape != state_dict['pos_embed'].shape: + logger.info(msg=f'Resize the pos_embed shape from \ +{state_dict["pos_embed"].shape} to {self.pos_embed.shape}') + h, w = self.img_size + pos_size = int( + math.sqrt(state_dict['pos_embed'].shape[1] - 1)) + state_dict['pos_embed'] = self.resize_pos_embed( + state_dict['pos_embed'], (h, w), (pos_size, pos_size), + self.patch_size, self.interpolate_mode) + + self.load_state_dict(state_dict, False) + + elif pretrained is None: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'mlp' in n: + normal_init(m.bias, std=1e-6) + else: + constant_init(m.bias, 0) + elif isinstance(m, Conv2d): + kaiming_init(m.weight, mode='fan_in') + if m.bias is not None: + constant_init(m.bias, 0) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m.bias, 0) + constant_init(m.weight, 1.0) + else: + raise TypeError('pretrained must be a str or None') + + def _pos_embeding(self, img, patched_img, pos_embed): + """Positiong embeding method. + + Resize the pos_embed, if the input image size doesn't match + the training size. + Args: + img (torch.Tensor): The inference image tensor, the shape + must be [B, C, H, W]. + patched_img (torch.Tensor): The patched image, it should be + shape of [B, L1, C]. + pos_embed (torch.Tensor): The pos_embed weighs, it should be + shape of [B, L2, c]. + Return: + torch.Tensor: The pos encoded image feature. + """ + assert patched_img.ndim == 3 and pos_embed.ndim == 3, \ + 'the shapes of patched_img and pos_embed must be [B, L, C]' + x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] + if x_len != pos_len: + if pos_len == (self.img_size[0] // self.patch_size) * ( + self.img_size[1] // self.patch_size) + 1: + pos_h = self.img_size[0] // self.patch_size + pos_w = self.img_size[1] // self.patch_size + else: + raise ValueError( + 'Unexpected shape of pos_embed, got {}.'.format( + pos_embed.shape)) + pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:], + (pos_h, pos_w), self.patch_size, + self.interpolate_mode) + return self.pos_drop(patched_img + pos_embed) + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode): + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): pos_embed weights. + input_shpae (tuple): Tuple for (input_h, intput_w). + pos_shape (tuple): Tuple for (pos_h, pos_w). + patch_size (int): Patch size. + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + input_h, input_w = input_shpae + pos_h, pos_w = pos_shape + cls_token_weight = pos_embed[:, 0] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = F.interpolate( + pos_embed_weight, + size=[input_h // patch_size, input_w // patch_size], + align_corners=False, + mode=mode) + cls_token_weight = cls_token_weight.unsqueeze(1) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed + + def forward(self, inputs): + B = inputs.shape[0] + + x = self.patch_embed(inputs) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self._pos_embeding(inputs, x, self.pos_embed) + + if not self.with_cls_token: + # Remove class token for transformer input + x = x[:, 1:] + + outs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if i == len(self.blocks) - 1: + if self.final_norm: + x = self.norm(x) + if i in self.out_indices: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, inputs.shape[2] // self.patch_size, + inputs.shape[3] // self.patch_size, + C).permute(0, 3, 1, 2) + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super(VisionTransformer, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/annotator/uniformer/mmseg/models/builder.py b/annotator/uniformer/mmseg/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5b971252bfc971c3ffbaa27746d69b1d3ea9fd --- /dev/null +++ b/annotator/uniformer/mmseg/models/builder.py @@ -0,0 +1,46 @@ +import warnings + +from annotator.uniformer.mmcv.cnn import MODELS as MMCV_MODELS +from annotator.uniformer.mmcv.utils import Registry + +MODELS = Registry('models', parent=MMCV_MODELS) + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +SEGMENTORS = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_segmentor(cfg, train_cfg=None, test_cfg=None): + """Build segmentor.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return SEGMENTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/annotator/uniformer/mmseg/models/decode_heads/__init__.py b/annotator/uniformer/mmseg/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac66d3cfe0ea04af45c0f3594bf135841c3812e3 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/__init__.py @@ -0,0 +1,28 @@ +from .ann_head import ANNHead +from .apc_head import APCHead +from .aspp_head import ASPPHead +from .cc_head import CCHead +from .da_head import DAHead +from .dm_head import DMHead +from .dnl_head import DNLHead +from .ema_head import EMAHead +from .enc_head import EncHead +from .fcn_head import FCNHead +from .fpn_head import FPNHead +from .gc_head import GCHead +from .lraspp_head import LRASPPHead +from .nl_head import NLHead +from .ocr_head import OCRHead +# from .point_head import PointHead +from .psa_head import PSAHead +from .psp_head import PSPHead +from .sep_aspp_head import DepthwiseSeparableASPPHead +from .sep_fcn_head import DepthwiseSeparableFCNHead +from .uper_head import UPerHead + +__all__ = [ + 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', + 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', + 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', + 'APCHead', 'DMHead', 'LRASPPHead' +] diff --git a/annotator/uniformer/mmseg/models/decode_heads/ann_head.py b/annotator/uniformer/mmseg/models/decode_heads/ann_head.py new file mode 100644 index 0000000000000000000000000000000000000000..30aaacc2cafc568d3de71d1477b4de0dc0fea9d3 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/ann_head.py @@ -0,0 +1,245 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule + +from ..builder import HEADS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PPMConcat(nn.ModuleList): + """Pyramid Pooling Module that only concat the features of each layer. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + """ + + def __init__(self, pool_scales=(1, 3, 6, 8)): + super(PPMConcat, self).__init__( + [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales]) + + def forward(self, feats): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(feats) + ppm_outs.append(ppm_out.view(*feats.shape[:2], -1)) + concat_outs = torch.cat(ppm_outs, dim=2) + return concat_outs + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Make a ANN used SelfAttentionBlock. + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + share_key_query (bool): Whether share projection weight between key + and query projection. + query_scale (int): The scale of query feature map. + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, share_key_query, query_scale, key_pool_scales, + conv_cfg, norm_cfg, act_cfg): + key_psp = PPMConcat(key_pool_scales) + if query_scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=query_scale) + else: + query_downsample = None + super(SelfAttentionBlock, self).__init__( + key_in_channels=low_in_channels, + query_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=share_key_query, + query_downsample=query_downsample, + key_downsample=key_psp, + key_query_num_convs=1, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + +class AFNB(nn.Module): + """Asymmetric Fusion Non-local Block(AFNB) + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + and query projection. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, query_scales, key_pool_scales, conv_cfg, + norm_cfg, act_cfg): + super(AFNB, self).__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=False, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + out_channels + high_in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, low_feats, high_feats): + """Forward function.""" + priors = [stage(high_feats, low_feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, high_feats], 1)) + return output + + +class APNB(nn.Module): + """Asymmetric Pyramid Non-local Block (APNB) + + Args: + in_channels (int): Input channels of key/query feature, + which is the key feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, in_channels, channels, out_channels, query_scales, + key_pool_scales, conv_cfg, norm_cfg, act_cfg): + super(APNB, self).__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=in_channels, + high_in_channels=in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=True, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + 2 * in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, feats): + """Forward function.""" + priors = [stage(feats, feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, feats], 1)) + return output + + +@HEADS.register_module() +class ANNHead(BaseDecodeHead): + """Asymmetric Non-local Neural Networks for Semantic Segmentation. + + This head is the implementation of `ANNNet + `_. + + Args: + project_channels (int): Projection channels for Nonlocal. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): The pooling scales of key feature map. + Default: (1, 3, 6, 8). + """ + + def __init__(self, + project_channels, + query_scales=(1, ), + key_pool_scales=(1, 3, 6, 8), + **kwargs): + super(ANNHead, self).__init__( + input_transform='multiple_select', **kwargs) + assert len(self.in_channels) == 2 + low_in_channels, high_in_channels = self.in_channels + self.project_channels = project_channels + self.fusion = AFNB( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + out_channels=high_in_channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + high_in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.context = APNB( + in_channels=self.channels, + out_channels=self.channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + low_feats, high_feats = self._transform_inputs(inputs) + output = self.fusion(low_feats, high_feats) + output = self.dropout(output) + output = self.bottleneck(output) + output = self.context(output) + output = self.cls_seg(output) + + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/apc_head.py b/annotator/uniformer/mmseg/models/decode_heads/apc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c7038bdbe0edf2a1f184b6899486d2d190dda076 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/apc_head.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class ACM(nn.Module): + """Adaptive Context Module used in APCNet. + + Args: + pool_scale (int): Pooling scale used in Adaptive Context + Module to extract region features. + fusion (bool): Add one conv to fuse residual feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super(ACM, self).__init__() + self.pool_scale = pool_scale + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.pooled_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.global_info = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) + + self.residual_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) + # [batch_size, channels, h, w] + x = self.input_redu_conv(x) + # [batch_size, channels, pool_scale, pool_scale] + pooled_x = self.pooled_redu_conv(pooled_x) + batch_size = x.size(0) + # [batch_size, pool_scale * pool_scale, channels] + pooled_x = pooled_x.view(batch_size, self.channels, + -1).permute(0, 2, 1).contiguous() + # [batch_size, h * w, pool_scale * pool_scale] + affinity_matrix = self.gla(x + resize( + self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) + ).permute(0, 2, 3, 1).reshape( + batch_size, -1, self.pool_scale**2) + affinity_matrix = F.sigmoid(affinity_matrix) + # [batch_size, h * w, channels] + z_out = torch.matmul(affinity_matrix, pooled_x) + # [batch_size, channels, h * w] + z_out = z_out.permute(0, 2, 1).contiguous() + # [batch_size, channels, h, w] + z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) + z_out = self.residual_conv(z_out) + z_out = F.relu(z_out + x) + if self.fusion: + z_out = self.fusion_conv(z_out) + + return z_out + + +@HEADS.register_module() +class APCHead(BaseDecodeHead): + """Adaptive Pyramid Context Network for Semantic Segmentation. + + This head is the implementation of + `APCNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Adaptive Context + Module. Default: (1, 2, 3, 6). + fusion (bool): Add one conv to fuse residual feature. + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): + super(APCHead, self).__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.fusion = fusion + acm_modules = [] + for pool_scale in self.pool_scales: + acm_modules.append( + ACM(pool_scale, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.acm_modules = nn.ModuleList(acm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + acm_outs = [x] + for acm_module in self.acm_modules: + acm_outs.append(acm_module(x)) + acm_outs = torch.cat(acm_outs, dim=1) + output = self.bottleneck(acm_outs) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..aa914b5bb25124d1ff199553d96713d6a80484c0 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class ASPPModule(nn.ModuleList): + """Atrous Spatial Pyramid Pooling (ASPP) Module. + + Args: + dilations (tuple[int]): Dilation rate of each layer. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, + act_cfg): + super(ASPPModule, self).__init__() + self.dilations = dilations + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for dilation in dilations: + self.append( + ConvModule( + self.in_channels, + self.channels, + 1 if dilation == 1 else 3, + dilation=dilation, + padding=0 if dilation == 1 else dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, x): + """Forward function.""" + aspp_outs = [] + for aspp_module in self: + aspp_outs.append(aspp_module(x)) + + return aspp_outs + + +@HEADS.register_module() +class ASPPHead(BaseDecodeHead): + """Rethinking Atrous Convolution for Semantic Image Segmentation. + + This head is the implementation of `DeepLabV3 + `_. + + Args: + dilations (tuple[int]): Dilation rates for ASPP module. + Default: (1, 6, 12, 18). + """ + + def __init__(self, dilations=(1, 6, 12, 18), **kwargs): + super(ASPPHead, self).__init__(**kwargs) + assert isinstance(dilations, (list, tuple)) + self.dilations = dilations + self.image_pool = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.aspp_modules = ASPPModule( + dilations, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + (len(dilations) + 1) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + output = self.bottleneck(aspp_outs) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py b/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d02122ca0e68743b1bf7a893afae96042f23838c --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py @@ -0,0 +1,57 @@ +from abc import ABCMeta, abstractmethod + +from .decode_head import BaseDecodeHead + + +class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): + """Base class for cascade decode head used in + :class:`CascadeEncoderDecoder.""" + + def __init__(self, *args, **kwargs): + super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) + + @abstractmethod + def forward(self, inputs, prev_output): + """Placeholder of forward function.""" + pass + + def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, + train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs, prev_output) + losses = self.losses(seg_logits, gt_semantic_seg) + + return losses + + def forward_test(self, inputs, prev_output, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + return self.forward(inputs, prev_output) diff --git a/annotator/uniformer/mmseg/models/decode_heads/cc_head.py b/annotator/uniformer/mmseg/models/decode_heads/cc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9abb4e747f92657f4220b29788539340986c00 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/cc_head.py @@ -0,0 +1,42 @@ +import torch + +from ..builder import HEADS +from .fcn_head import FCNHead + +try: + from annotator.uniformer.mmcv.ops import CrissCrossAttention +except ModuleNotFoundError: + CrissCrossAttention = None + + +@HEADS.register_module() +class CCHead(FCNHead): + """CCNet: Criss-Cross Attention for Semantic Segmentation. + + This head is the implementation of `CCNet + `_. + + Args: + recurrence (int): Number of recurrence of Criss Cross Attention + module. Default: 2. + """ + + def __init__(self, recurrence=2, **kwargs): + if CrissCrossAttention is None: + raise RuntimeError('Please install mmcv-full for ' + 'CrissCrossAttention ops') + super(CCHead, self).__init__(num_convs=2, **kwargs) + self.recurrence = recurrence + self.cca = CrissCrossAttention(self.channels) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + for _ in range(self.recurrence): + output = self.cca(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/da_head.py b/annotator/uniformer/mmseg/models/decode_heads/da_head.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd49fcfdc7c0a70f9485cc71843dcf3e0cb1774 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/da_head.py @@ -0,0 +1,178 @@ +import torch +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule, Scale +from torch import nn + +from annotator.uniformer.mmseg.core import add_prefix +from ..builder import HEADS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PAM(_SelfAttentionBlock): + """Position Attention Module (PAM) + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. + """ + + def __init__(self, in_channels, channels): + super(PAM, self).__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=1, + key_query_norm=False, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=False, + with_out=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None) + + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + out = super(PAM, self).forward(x, x) + + out = self.gamma(out) + x + return out + + +class CAM(nn.Module): + """Channel Attention Module (CAM)""" + + def __init__(self): + super(CAM, self).__init__() + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + batch_size, channels, height, width = x.size() + proj_query = x.view(batch_size, channels, -1) + proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max( + energy, -1, keepdim=True)[0].expand_as(energy) - energy + attention = F.softmax(energy_new, dim=-1) + proj_value = x.view(batch_size, channels, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(batch_size, channels, height, width) + + out = self.gamma(out) + x + return out + + +@HEADS.register_module() +class DAHead(BaseDecodeHead): + """Dual Attention Network for Scene Segmentation. + + This head is the implementation of `DANet + `_. + + Args: + pam_channels (int): The channels of Position Attention Module(PAM). + """ + + def __init__(self, pam_channels, **kwargs): + super(DAHead, self).__init__(**kwargs) + self.pam_channels = pam_channels + self.pam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam = PAM(self.channels, pam_channels) + self.pam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + self.cam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam = CAM() + self.cam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + def pam_cls_seg(self, feat): + """PAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.pam_conv_seg(feat) + return output + + def cam_cls_seg(self, feat): + """CAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.cam_conv_seg(feat) + return output + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + pam_feat = self.pam_in_conv(x) + pam_feat = self.pam(pam_feat) + pam_feat = self.pam_out_conv(pam_feat) + pam_out = self.pam_cls_seg(pam_feat) + + cam_feat = self.cam_in_conv(x) + cam_feat = self.cam(cam_feat) + cam_feat = self.cam_out_conv(cam_feat) + cam_out = self.cam_cls_seg(cam_feat) + + feat_sum = pam_feat + cam_feat + pam_cam_out = self.cls_seg(feat_sum) + + return pam_cam_out, pam_out, cam_out + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing, only ``pam_cam`` is used.""" + return self.forward(inputs)[0] + + def losses(self, seg_logit, seg_label): + """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" + pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit + loss = dict() + loss.update( + add_prefix( + super(DAHead, self).losses(pam_cam_seg_logit, seg_label), + 'pam_cam')) + loss.update( + add_prefix( + super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam')) + loss.update( + add_prefix( + super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam')) + return loss diff --git a/annotator/uniformer/mmseg/models/decode_heads/decode_head.py b/annotator/uniformer/mmseg/models/decode_heads/decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..88a661b8f6fec5d4c031d3d85e80777ee63951a6 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/decode_head.py @@ -0,0 +1,234 @@ +from abc import ABCMeta, abstractmethod + +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import normal_init +from annotator.uniformer.mmcv.runner import auto_fp16, force_fp32 + +from annotator.uniformer.mmseg.core import build_pixel_sampler +from annotator.uniformer.mmseg.ops import resize +from ..builder import build_loss +from ..losses import accuracy + + +class BaseDecodeHead(nn.Module, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + Args: + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. + loss_decode (dict): Config of decode loss. + Default: dict(type='CrossEntropyLoss'). + ignore_index (int | None): The label index to be ignored. When using + masked BCE loss, ignore_index should be set to None. Default: 255 + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + """ + + def __init__(self, + in_channels, + channels, + *, + num_classes, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + in_index=-1, + input_transform=None, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + ignore_index=255, + sampler=None, + align_corners=False): + super(BaseDecodeHead, self).__init__() + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.num_classes = num_classes + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + self.loss_decode = build_loss(loss_decode) + self.ignore_index = ignore_index + self.align_corners = align_corners + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.fp16_enabled = False + + def extra_repr(self): + """Extra repr.""" + s = f'input_transform={self.input_transform}, ' \ + f'ignore_index={self.ignore_index}, ' \ + f'align_corners={self.align_corners}' + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def init_weights(self): + """Initialize weights of classification layer.""" + normal_init(self.conv_seg, mean=0, std=0.01) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @auto_fp16() + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.losses(seg_logits, gt_semantic_seg) + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + return self.forward(inputs) + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + @force_fp32(apply_to=('seg_logit', )) + def losses(self, seg_logit, seg_label): + """Compute segmentation loss.""" + loss = dict() + seg_logit = resize( + input=seg_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logit, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + loss['loss_seg'] = self.loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + loss['acc_seg'] = accuracy(seg_logit, seg_label) + return loss diff --git a/annotator/uniformer/mmseg/models/decode_heads/dm_head.py b/annotator/uniformer/mmseg/models/decode_heads/dm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..19c963923126b53ce22f60813540a35badf24b3d --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/dm_head.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer + +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class DCM(nn.Module): + """Dynamic Convolutional Module used in DMNet. + + Args: + filter_size (int): The filter size of generated convolution kernel + used in Dynamic Convolutional Module. + fusion (bool): Add one conv to fuse DCM output feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super(DCM, self).__init__() + self.filter_size = filter_size + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, + 0) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.norm_cfg is not None: + self.norm = build_norm_layer(self.norm_cfg, self.channels)[1] + else: + self.norm = None + self.activate = build_activation_layer(self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + generated_filter = self.filter_gen_conv( + F.adaptive_avg_pool2d(x, self.filter_size)) + x = self.input_redu_conv(x) + b, c, h, w = x.shape + # [1, b * c, h, w], c = self.channels + x = x.view(1, b * c, h, w) + # [b * c, 1, filter_size, filter_size] + generated_filter = generated_filter.view(b * c, 1, self.filter_size, + self.filter_size) + pad = (self.filter_size - 1) // 2 + if (self.filter_size - 1) % 2 == 0: + p2d = (pad, pad, pad, pad) + else: + p2d = (pad + 1, pad, pad + 1, pad) + x = F.pad(input=x, pad=p2d, mode='constant', value=0) + # [1, b * c, h, w] + output = F.conv2d(input=x, weight=generated_filter, groups=b * c) + # [b, c, h, w] + output = output.view(b, c, h, w) + if self.norm is not None: + output = self.norm(output) + output = self.activate(output) + + if self.fusion: + output = self.fusion_conv(output) + + return output + + +@HEADS.register_module() +class DMHead(BaseDecodeHead): + """Dynamic Multi-scale Filters for Semantic Segmentation. + + This head is the implementation of + `DMNet `_. + + Args: + filter_sizes (tuple[int]): The size of generated convolutional filters + used in Dynamic Convolutional Module. Default: (1, 3, 5, 7). + fusion (bool): Add one conv to fuse DCM output feature. + """ + + def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): + super(DMHead, self).__init__(**kwargs) + assert isinstance(filter_sizes, (list, tuple)) + self.filter_sizes = filter_sizes + self.fusion = fusion + dcm_modules = [] + for filter_size in self.filter_sizes: + dcm_modules.append( + DCM(filter_size, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.dcm_modules = nn.ModuleList(dcm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(filter_sizes) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + dcm_outs = [x] + for dcm_module in self.dcm_modules: + dcm_outs.append(dcm_module(x)) + dcm_outs = torch.cat(dcm_outs, dim=1) + output = self.bottleneck(dcm_outs) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py b/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..333280c5947066fd3c7ebcfe302a0e7ad65480d5 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py @@ -0,0 +1,131 @@ +import torch +from annotator.uniformer.mmcv.cnn import NonLocal2d +from torch import nn + +from ..builder import HEADS +from .fcn_head import FCNHead + + +class DisentangledNonLocal2d(NonLocal2d): + """Disentangled Non-Local Blocks. + + Args: + temperature (float): Temperature to adjust attention. Default: 0.05 + """ + + def __init__(self, *arg, temperature, **kwargs): + super().__init__(*arg, **kwargs) + self.temperature = temperature + self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1) + + def embedded_gaussian(self, theta_x, phi_x): + """Embedded gaussian with temperature.""" + + # NonLocal2d pairwise_weight: [N, HxW, HxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + if self.use_scale: + # theta_x.shape[-1] is `self.inter_channels` + pairwise_weight /= theta_x.shape[-1]**0.5 + pairwise_weight /= self.temperature + pairwise_weight = pairwise_weight.softmax(dim=-1) + return pairwise_weight + + def forward(self, x): + # x: [N, C, H, W] + n = x.size(0) + + # g_x: [N, HxW, C] + g_x = self.g(x).view(n, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # theta_x: [N, HxW, C], phi_x: [N, C, HxW] + if self.mode == 'gaussian': + theta_x = x.view(n, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + if self.sub_sample: + phi_x = self.phi(x).view(n, self.in_channels, -1) + else: + phi_x = x.view(n, self.in_channels, -1) + elif self.mode == 'concatenation': + theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) + phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) + else: + theta_x = self.theta(x).view(n, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(n, self.inter_channels, -1) + + # subtract mean + theta_x -= theta_x.mean(dim=-2, keepdim=True) + phi_x -= phi_x.mean(dim=-1, keepdim=True) + + pairwise_func = getattr(self, self.mode) + # pairwise_weight: [N, HxW, HxW] + pairwise_weight = pairwise_func(theta_x, phi_x) + + # y: [N, HxW, C] + y = torch.matmul(pairwise_weight, g_x) + # y: [N, C, H, W] + y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, + *x.size()[2:]) + + # unary_mask: [N, 1, HxW] + unary_mask = self.conv_mask(x) + unary_mask = unary_mask.view(n, 1, -1) + unary_mask = unary_mask.softmax(dim=-1) + # unary_x: [N, 1, C] + unary_x = torch.matmul(unary_mask, g_x) + # unary_x: [N, C, 1, 1] + unary_x = unary_x.permute(0, 2, 1).contiguous().reshape( + n, self.inter_channels, 1, 1) + + output = x + self.conv_out(y + unary_x) + + return output + + +@HEADS.register_module() +class DNLHead(FCNHead): + """Disentangled Non-Local Neural Networks. + + This head is the implementation of `DNLNet + `_. + + Args: + reduction (int): Reduction factor of projection transform. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + sqrt(1/inter_channels). Default: False. + mode (str): The nonlocal mode. Options are 'embedded_gaussian', + 'dot_product'. Default: 'embedded_gaussian.'. + temperature (float): Temperature to adjust attention. Default: 0.05 + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + temperature=0.05, + **kwargs): + super(DNLHead, self).__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.temperature = temperature + self.dnl_block = DisentangledNonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode, + temperature=self.temperature) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.dnl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/ema_head.py b/annotator/uniformer/mmseg/models/decode_heads/ema_head.py new file mode 100644 index 0000000000000000000000000000000000000000..12267cb40569d2b5a4a2955a6dc2671377ff5e0a --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/ema_head.py @@ -0,0 +1,168 @@ +import math + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule + +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +def reduce_mean(tensor): + """Reduce mean when distributed training.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +class EMAModule(nn.Module): + """Expectation Maximization Attention Module used in EMANet. + + Args: + channels (int): Channels of the whole module. + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + """ + + def __init__(self, channels, num_bases, num_stages, momentum): + super(EMAModule, self).__init__() + assert num_stages >= 1, 'num_stages must be at least 1!' + self.num_bases = num_bases + self.num_stages = num_stages + self.momentum = momentum + + bases = torch.zeros(1, channels, self.num_bases) + bases.normal_(0, math.sqrt(2. / self.num_bases)) + # [1, channels, num_bases] + bases = F.normalize(bases, dim=1, p=2) + self.register_buffer('bases', bases) + + def forward(self, feats): + """Forward function.""" + batch_size, channels, height, width = feats.size() + # [batch_size, channels, height*width] + feats = feats.view(batch_size, channels, height * width) + # [batch_size, channels, num_bases] + bases = self.bases.repeat(batch_size, 1, 1) + + with torch.no_grad(): + for i in range(self.num_stages): + # [batch_size, height*width, num_bases] + attention = torch.einsum('bcn,bck->bnk', feats, bases) + attention = F.softmax(attention, dim=2) + # l1 norm + attention_normed = F.normalize(attention, dim=1, p=1) + # [batch_size, channels, num_bases] + bases = torch.einsum('bcn,bnk->bck', feats, attention_normed) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + + feats_recon = torch.einsum('bck,bnk->bcn', bases, attention) + feats_recon = feats_recon.view(batch_size, channels, height, width) + + if self.training: + bases = bases.mean(dim=0, keepdim=True) + bases = reduce_mean(bases) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + self.bases = (1 - + self.momentum) * self.bases + self.momentum * bases + + return feats_recon + + +@HEADS.register_module() +class EMAHead(BaseDecodeHead): + """Expectation Maximization Attention Networks for Semantic Segmentation. + + This head is the implementation of `EMANet + `_. + + Args: + ema_channels (int): EMA module channels + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + concat_input (bool): Whether concat the input and output of convs + before classification layer. Default: True + momentum (float): Momentum to update the base. Default: 0.1. + """ + + def __init__(self, + ema_channels, + num_bases, + num_stages, + concat_input=True, + momentum=0.1, + **kwargs): + super(EMAHead, self).__init__(**kwargs) + self.ema_channels = ema_channels + self.num_bases = num_bases + self.num_stages = num_stages + self.concat_input = concat_input + self.momentum = momentum + self.ema_module = EMAModule(self.ema_channels, self.num_bases, + self.num_stages, self.momentum) + + self.ema_in_conv = ConvModule( + self.in_channels, + self.ema_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # project (0, inf) -> (-inf, inf) + self.ema_mid_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=None, + act_cfg=None) + for param in self.ema_mid_conv.parameters(): + param.requires_grad = False + + self.ema_out_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.bottleneck = ConvModule( + self.ema_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.ema_in_conv(x) + identity = feats + feats = self.ema_mid_conv(feats) + recon = self.ema_module(feats) + recon = F.relu(recon, inplace=True) + recon = self.ema_out_conv(recon) + output = F.relu(identity + recon, inplace=True) + output = self.bottleneck(output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/enc_head.py b/annotator/uniformer/mmseg/models/decode_heads/enc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..da57af617e05d41761628fd2d6d232655b32d905 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/enc_head.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule, build_norm_layer + +from annotator.uniformer.mmseg.ops import Encoding, resize +from ..builder import HEADS, build_loss +from .decode_head import BaseDecodeHead + + +class EncModule(nn.Module): + """Encoding Module used in EncNet. + + Args: + in_channels (int): Input channels. + num_codes (int): Number of code words. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): + super(EncModule, self).__init__() + self.encoding_project = ConvModule( + in_channels, + in_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # TODO: resolve this hack + # change to 1d + if norm_cfg is not None: + encoding_norm_cfg = norm_cfg.copy() + if encoding_norm_cfg['type'] in ['BN', 'IN']: + encoding_norm_cfg['type'] += '1d' + else: + encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( + '2d', '1d') + else: + # fallback to BN1d + encoding_norm_cfg = dict(type='BN1d') + self.encoding = nn.Sequential( + Encoding(channels=in_channels, num_codes=num_codes), + build_norm_layer(encoding_norm_cfg, num_codes)[1], + nn.ReLU(inplace=True)) + self.fc = nn.Sequential( + nn.Linear(in_channels, in_channels), nn.Sigmoid()) + + def forward(self, x): + """Forward function.""" + encoding_projection = self.encoding_project(x) + encoding_feat = self.encoding(encoding_projection).mean(dim=1) + batch_size, channels, _, _ = x.size() + gamma = self.fc(encoding_feat) + y = gamma.view(batch_size, channels, 1, 1) + output = F.relu_(x + x * y) + return encoding_feat, output + + +@HEADS.register_module() +class EncHead(BaseDecodeHead): + """Context Encoding for Semantic Segmentation. + + This head is the implementation of `EncNet + `_. + + Args: + num_codes (int): Number of code words. Default: 32. + use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to + regularize the training. Default: True. + add_lateral (bool): Whether use lateral connection to fuse features. + Default: False. + loss_se_decode (dict): Config of decode loss. + Default: dict(type='CrossEntropyLoss', use_sigmoid=True). + """ + + def __init__(self, + num_codes=32, + use_se_loss=True, + add_lateral=False, + loss_se_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=0.2), + **kwargs): + super(EncHead, self).__init__( + input_transform='multiple_select', **kwargs) + self.use_se_loss = use_se_loss + self.add_lateral = add_lateral + self.num_codes = num_codes + self.bottleneck = ConvModule( + self.in_channels[-1], + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if add_lateral: + self.lateral_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the last one + self.lateral_convs.append( + ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.fusion = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.enc_module = EncModule( + self.channels, + num_codes=num_codes, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.use_se_loss: + self.loss_se_decode = build_loss(loss_se_decode) + self.se_layer = nn.Linear(self.channels, self.num_classes) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + feat = self.bottleneck(inputs[-1]) + if self.add_lateral: + laterals = [ + resize( + lateral_conv(inputs[i]), + size=feat.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + feat = self.fusion(torch.cat([feat, *laterals], 1)) + encode_feat, output = self.enc_module(feat) + output = self.cls_seg(output) + if self.use_se_loss: + se_output = self.se_layer(encode_feat) + return output, se_output + else: + return output + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing, ignore se_loss.""" + if self.use_se_loss: + return self.forward(inputs)[0] + else: + return self.forward(inputs) + + @staticmethod + def _convert_to_onehot_labels(seg_label, num_classes): + """Convert segmentation label to onehot. + + Args: + seg_label (Tensor): Segmentation label of shape (N, H, W). + num_classes (int): Number of classes. + + Returns: + Tensor: Onehot labels of shape (N, num_classes). + """ + + batch_size = seg_label.size(0) + onehot_labels = seg_label.new_zeros((batch_size, num_classes)) + for i in range(batch_size): + hist = seg_label[i].float().histc( + bins=num_classes, min=0, max=num_classes - 1) + onehot_labels[i] = hist > 0 + return onehot_labels + + def losses(self, seg_logit, seg_label): + """Compute segmentation and semantic encoding loss.""" + seg_logit, se_seg_logit = seg_logit + loss = dict() + loss.update(super(EncHead, self).losses(seg_logit, seg_label)) + se_loss = self.loss_se_decode( + se_seg_logit, + self._convert_to_onehot_labels(seg_label, self.num_classes)) + loss['loss_se'] = se_loss + return loss diff --git a/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py b/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..edb32c283fa4baada6b4a0bf3f7540c3580c3468 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule + +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class FCNHead(BaseDecodeHead): + """Fully Convolution Networks for Semantic Segmentation. + + This head is implemented of `FCNNet `_. + + Args: + num_convs (int): Number of convs in the head. Default: 2. + kernel_size (int): The kernel size for convs in the head. Default: 3. + concat_input (bool): Whether concat the input and output of convs + before classification layer. + dilation (int): The dilation rate for convs in the head. Default: 1. + """ + + def __init__(self, + num_convs=2, + kernel_size=3, + concat_input=True, + dilation=1, + **kwargs): + assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) + self.num_convs = num_convs + self.concat_input = concat_input + self.kernel_size = kernel_size + super(FCNHead, self).__init__(**kwargs) + if num_convs == 0: + assert self.in_channels == self.channels + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + ConvModule( + self.in_channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + for i in range(num_convs - 1): + convs.append( + ConvModule( + self.channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs(x) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py b/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1241c55b0813d1ecdddf1e66e7c5031fbf78ed50 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py @@ -0,0 +1,68 @@ +import numpy as np +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class FPNHead(BaseDecodeHead): + """Panoptic Feature Pyramid Networks. + + This head is the implementation of `Semantic FPN + `_. + + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. + """ + + def __init__(self, feature_strides, **kwargs): + super(FPNHead, self).__init__( + input_transform='multiple_select', **kwargs) + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + head_length = max( + 1, + int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + scale_head = [] + for k in range(head_length): + scale_head.append( + ConvModule( + self.in_channels[i] if k == 0 else self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if feature_strides[i] != feature_strides[0]: + scale_head.append( + nn.Upsample( + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + def forward(self, inputs): + + x = self._transform_inputs(inputs) + + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/gc_head.py b/annotator/uniformer/mmseg/models/decode_heads/gc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..70741245af975800840709911bd18d72247e3e04 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/gc_head.py @@ -0,0 +1,47 @@ +import torch +from annotator.uniformer.mmcv.cnn import ContextBlock + +from ..builder import HEADS +from .fcn_head import FCNHead + + +@HEADS.register_module() +class GCHead(FCNHead): + """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. + + This head is the implementation of `GCNet + `_. + + Args: + ratio (float): Multiplier of channels ratio. Default: 1/4. + pooling_type (str): The pooling type of context aggregation. + Options are 'att', 'avg'. Default: 'avg'. + fusion_types (tuple[str]): The fusion type for feature fusion. + Options are 'channel_add', 'channel_mul'. Default: ('channel_add',) + """ + + def __init__(self, + ratio=1 / 4., + pooling_type='att', + fusion_types=('channel_add', ), + **kwargs): + super(GCHead, self).__init__(num_convs=2, **kwargs) + self.ratio = ratio + self.pooling_type = pooling_type + self.fusion_types = fusion_types + self.gc_block = ContextBlock( + in_channels=self.channels, + ratio=self.ratio, + pooling_type=self.pooling_type, + fusion_types=self.fusion_types) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.gc_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..69bf320934d787aaa11984a0c4effe9ad8015b22 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv import is_tuple_of +from annotator.uniformer.mmcv.cnn import ConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class LRASPPHead(BaseDecodeHead): + """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. + + This head is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + branch_channels (tuple[int]): The number of output channels in every + each branch. Default: (32, 64). + """ + + def __init__(self, branch_channels=(32, 64), **kwargs): + super(LRASPPHead, self).__init__(**kwargs) + if self.input_transform != 'multiple_select': + raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' + f'must be \'multiple_select\'. But received ' + f'\'{self.input_transform}\'') + assert is_tuple_of(branch_channels, int) + assert len(branch_channels) == len(self.in_channels) - 1 + self.branch_channels = branch_channels + + self.convs = nn.Sequential() + self.conv_ups = nn.Sequential() + for i in range(len(branch_channels)): + self.convs.add_module( + f'conv{i}', + nn.Conv2d( + self.in_channels[i], branch_channels[i], 1, bias=False)) + self.conv_ups.add_module( + f'conv_up{i}', + ConvModule( + self.channels + branch_channels[i], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False)) + + self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) + + self.aspp_conv = ConvModule( + self.in_channels[-1], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False) + self.image_pool = nn.Sequential( + nn.AvgPool2d(kernel_size=49, stride=(16, 20)), + ConvModule( + self.in_channels[2], + self.channels, + 1, + act_cfg=dict(type='Sigmoid'), + bias=False)) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + x = inputs[-1] + + x = self.aspp_conv(x) * resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = self.conv_up_input(x) + + for i in range(len(self.branch_channels) - 1, -1, -1): + x = resize( + x, + size=inputs[i].size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = torch.cat([x, self.convs[i](inputs[i])], 1) + x = self.conv_ups[i](x) + + return self.cls_seg(x) diff --git a/annotator/uniformer/mmseg/models/decode_heads/nl_head.py b/annotator/uniformer/mmseg/models/decode_heads/nl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3eee424199e6aa363b564e2a3340a070db04db86 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/nl_head.py @@ -0,0 +1,49 @@ +import torch +from annotator.uniformer.mmcv.cnn import NonLocal2d + +from ..builder import HEADS +from .fcn_head import FCNHead + + +@HEADS.register_module() +class NLHead(FCNHead): + """Non-local Neural Networks. + + This head is the implementation of `NLNet + `_. + + Args: + reduction (int): Reduction factor of projection transform. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + sqrt(1/inter_channels). Default: True. + mode (str): The nonlocal mode. Options are 'embedded_gaussian', + 'dot_product'. Default: 'embedded_gaussian.'. + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + **kwargs): + super(NLHead, self).__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.nl_block = NonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.nl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py b/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..715852e94e81dc46623972748285d2d19237a341 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .cascade_decode_head import BaseCascadeDecodeHead + + +class SpatialGatherModule(nn.Module): + """Aggregate the context features according to the initial predicted + probability distribution. + + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, scale): + super(SpatialGatherModule, self).__init__() + self.scale = scale + + def forward(self, feats, probs): + """Forward function.""" + batch_size, num_classes, height, width = probs.size() + channels = feats.size(1) + probs = probs.view(batch_size, num_classes, -1) + feats = feats.view(batch_size, channels, -1) + # [batch_size, height*width, num_classes] + feats = feats.permute(0, 2, 1) + # [batch_size, channels, height*width] + probs = F.softmax(self.scale * probs, dim=2) + # [batch_size, channels, num_classes] + ocr_context = torch.matmul(probs, feats) + ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) + return ocr_context + + +class ObjectAttentionBlock(_SelfAttentionBlock): + """Make a OCR used SelfAttentionBlock.""" + + def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, + act_cfg): + if scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=scale) + else: + query_downsample = None + super(ObjectAttentionBlock, self).__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=query_downsample, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=True, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.bottleneck = ConvModule( + in_channels * 2, + in_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, query_feats, key_feats): + """Forward function.""" + context = super(ObjectAttentionBlock, + self).forward(query_feats, key_feats) + output = self.bottleneck(torch.cat([context, query_feats], dim=1)) + if self.query_downsample is not None: + output = resize(query_feats) + + return output + + +@HEADS.register_module() +class OCRHead(BaseCascadeDecodeHead): + """Object-Contextual Representations for Semantic Segmentation. + + This head is the implementation of `OCRNet + `_. + + Args: + ocr_channels (int): The intermediate channels of OCR block. + scale (int): The scale of probability map in SpatialGatherModule in + Default: 1. + """ + + def __init__(self, ocr_channels, scale=1, **kwargs): + super(OCRHead, self).__init__(**kwargs) + self.ocr_channels = ocr_channels + self.scale = scale + self.object_context_block = ObjectAttentionBlock( + self.channels, + self.ocr_channels, + self.scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.spatial_gather_module = SpatialGatherModule(self.scale) + + self.bottleneck = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs, prev_output): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.bottleneck(x) + context = self.spatial_gather_module(feats, prev_output) + object_context = self.object_context_block(feats, context) + output = self.cls_seg(object_context) + + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/point_head.py b/annotator/uniformer/mmseg/models/decode_heads/point_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3342aa28bb8d264b2c3d01cbf5098d145943c193 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/point_head.py @@ -0,0 +1,349 @@ +# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa + +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule, normal_init +from annotator.uniformer.mmcv.ops import point_sample + +from annotator.uniformer.mmseg.models.builder import HEADS +from annotator.uniformer.mmseg.ops import resize +from ..losses import accuracy +from .cascade_decode_head import BaseCascadeDecodeHead + + +def calculate_uncertainty(seg_logits): + """Estimate uncertainty based on seg logits. + + For each location of the prediction ``seg_logits`` we estimate + uncertainty as the difference between top first and top second + predicted logits. + + Args: + seg_logits (Tensor): Semantic segmentation logits, + shape (batch_size, num_classes, height, width). + + Returns: + scores (Tensor): T uncertainty scores with the most uncertain + locations having the highest uncertainty score, shape ( + batch_size, 1, height, width) + """ + top2_scores = torch.topk(seg_logits, k=2, dim=1)[0] + return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) + + +@HEADS.register_module() +class PointHead(BaseCascadeDecodeHead): + """A mask point head use in PointRend. + + ``PointHead`` use shared multi-layer perceptron (equivalent to + nn.Conv1d) to predict the logit of input points. The fine-grained feature + and coarse feature will be concatenate together for predication. + + Args: + num_fcs (int): Number of fc layers in the head. Default: 3. + in_channels (int): Number of input channels. Default: 256. + fc_channels (int): Number of fc channels. Default: 256. + num_classes (int): Number of classes for logits. Default: 80. + class_agnostic (bool): Whether use class agnostic classification. + If so, the output channels of logits will be 1. Default: False. + coarse_pred_each_layer (bool): Whether concatenate coarse feature with + the output of each fc layer. Default: True. + conv_cfg (dict|None): Dictionary to construct and config conv layer. + Default: dict(type='Conv1d')) + norm_cfg (dict|None): Dictionary to construct and config norm layer. + Default: None. + loss_point (dict): Dictionary to construct and config loss layer of + point head. Default: dict(type='CrossEntropyLoss', use_mask=True, + loss_weight=1.0). + """ + + def __init__(self, + num_fcs=3, + coarse_pred_each_layer=True, + conv_cfg=dict(type='Conv1d'), + norm_cfg=None, + act_cfg=dict(type='ReLU', inplace=False), + **kwargs): + super(PointHead, self).__init__( + input_transform='multiple_select', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs) + + self.num_fcs = num_fcs + self.coarse_pred_each_layer = coarse_pred_each_layer + + fc_in_channels = sum(self.in_channels) + self.num_classes + fc_channels = self.channels + self.fcs = nn.ModuleList() + for k in range(num_fcs): + fc = ConvModule( + fc_in_channels, + fc_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.fcs.append(fc) + fc_in_channels = fc_channels + fc_in_channels += self.num_classes if self.coarse_pred_each_layer \ + else 0 + self.fc_seg = nn.Conv1d( + fc_in_channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0) + if self.dropout_ratio > 0: + self.dropout = nn.Dropout(self.dropout_ratio) + delattr(self, 'conv_seg') + + def init_weights(self): + """Initialize weights of classification layer.""" + normal_init(self.fc_seg, std=0.001) + + def cls_seg(self, feat): + """Classify each pixel with fc.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.fc_seg(feat) + return output + + def forward(self, fine_grained_point_feats, coarse_point_feats): + x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1) + for fc in self.fcs: + x = fc(x) + if self.coarse_pred_each_layer: + x = torch.cat((x, coarse_point_feats), dim=1) + return self.cls_seg(x) + + def _get_fine_grained_point_feats(self, x, points): + """Sample from fine grained features. + + Args: + x (list[Tensor]): Feature pyramid from by neck or backbone. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). + + Returns: + fine_grained_feats (Tensor): Sampled fine grained feature, + shape (batch_size, sum(channels of x), num_points). + """ + + fine_grained_feats_list = [ + point_sample(_, points, align_corners=self.align_corners) + for _ in x + ] + if len(fine_grained_feats_list) > 1: + fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1) + else: + fine_grained_feats = fine_grained_feats_list[0] + + return fine_grained_feats + + def _get_coarse_point_feats(self, prev_output, points): + """Sample from fine grained features. + + Args: + prev_output (list[Tensor]): Prediction of previous decode head. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). + + Returns: + coarse_feats (Tensor): Sampled coarse feature, shape (batch_size, + num_classes, num_points). + """ + + coarse_feats = point_sample( + prev_output, points, align_corners=self.align_corners) + + return coarse_feats + + def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, + train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + x = self._transform_inputs(inputs) + with torch.no_grad(): + points = self.get_points_train( + prev_output, calculate_uncertainty, cfg=train_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats(prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + point_label = point_sample( + gt_semantic_seg.float(), + points, + mode='nearest', + align_corners=self.align_corners) + point_label = point_label.squeeze(1).long() + + losses = self.losses(point_logits, point_label) + + return losses + + def forward_test(self, inputs, prev_output, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + + x = self._transform_inputs(inputs) + refined_seg_logits = prev_output.clone() + for _ in range(test_cfg.subdivision_steps): + refined_seg_logits = resize( + refined_seg_logits, + scale_factor=test_cfg.scale_factor, + mode='bilinear', + align_corners=self.align_corners) + batch_size, channels, height, width = refined_seg_logits.shape + point_indices, points = self.get_points_test( + refined_seg_logits, calculate_uncertainty, cfg=test_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats( + prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + + point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) + refined_seg_logits = refined_seg_logits.reshape( + batch_size, channels, height * width) + refined_seg_logits = refined_seg_logits.scatter_( + 2, point_indices, point_logits) + refined_seg_logits = refined_seg_logits.view( + batch_size, channels, height, width) + + return refined_seg_logits + + def losses(self, point_logits, point_label): + """Compute segmentation loss.""" + loss = dict() + loss['loss_point'] = self.loss_decode( + point_logits, point_label, ignore_index=self.ignore_index) + loss['acc_point'] = accuracy(point_logits, point_label) + return loss + + def get_points_train(self, seg_logits, uncertainty_func, cfg): + """Sample points for training. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'uncertainty_func' function that takes point's logit prediction as + input. + + Args: + seg_logits (Tensor): Semantic segmentation logits, shape ( + batch_size, num_classes, height, width). + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Training config of point head. + + Returns: + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains the coordinates of ``num_points`` sampled + points. + """ + num_points = cfg.num_points + oversample_ratio = cfg.oversample_ratio + importance_sample_ratio = cfg.importance_sample_ratio + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = seg_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=seg_logits.device) + point_logits = point_sample(seg_logits, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=seg_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_point_coords = torch.rand( + batch_size, num_random_points, 2, device=seg_logits.device) + point_coords = torch.cat((point_coords, rand_point_coords), dim=1) + return point_coords + + def get_points_test(self, seg_logits, uncertainty_func, cfg): + """Sample points for testing. + + Find ``num_points`` most uncertain points from ``uncertainty_map``. + + Args: + seg_logits (Tensor): A tensor of shape (batch_size, num_classes, + height, width) for class-specific or class-agnostic prediction. + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Testing config of point head. + + Returns: + point_indices (Tensor): A tensor of shape (batch_size, num_points) + that contains indices from [0, height x width) of the most + uncertain points. + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains [0, 1] x [0, 1] normalized coordinates of the + most uncertain points from the ``height x width`` grid . + """ + + num_points = cfg.subdivision_num_points + uncertainty_map = uncertainty_func(seg_logits) + batch_size, _, height, width = uncertainty_map.shape + h_step = 1.0 / height + w_step = 1.0 / width + + uncertainty_map = uncertainty_map.view(batch_size, height * width) + num_points = min(height * width, num_points) + point_indices = uncertainty_map.topk(num_points, dim=1)[1] + point_coords = torch.zeros( + batch_size, + num_points, + 2, + dtype=torch.float, + device=seg_logits.device) + point_coords[:, :, 0] = w_step / 2.0 + (point_indices % + width).float() * w_step + point_coords[:, :, 1] = h_step / 2.0 + (point_indices // + width).float() * h_step + return point_indices, point_coords diff --git a/annotator/uniformer/mmseg/models/decode_heads/psa_head.py b/annotator/uniformer/mmseg/models/decode_heads/psa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..480dbd1a081262e45bf87e32c4a339ac8f8b4ffb --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/psa_head.py @@ -0,0 +1,196 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + +try: + from annotator.uniformer.mmcv.ops import PSAMask +except ModuleNotFoundError: + PSAMask = None + + +@HEADS.register_module() +class PSAHead(BaseDecodeHead): + """Point-wise Spatial Attention Network for Scene Parsing. + + This head is the implementation of `PSANet + `_. + + Args: + mask_size (tuple[int]): The PSA mask size. It usually equals input + size. + psa_type (str): The type of psa module. Options are 'collect', + 'distribute', 'bi-direction'. Default: 'bi-direction' + compact (bool): Whether use compact map for 'collect' mode. + Default: True. + shrink_factor (int): The downsample factors of psa mask. Default: 2. + normalization_factor (float): The normalize factor of attention. + psa_softmax (bool): Whether use softmax for attention. + """ + + def __init__(self, + mask_size, + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + **kwargs): + if PSAMask is None: + raise RuntimeError('Please install mmcv-full for PSAMask ops') + super(PSAHead, self).__init__(**kwargs) + assert psa_type in ['collect', 'distribute', 'bi-direction'] + self.psa_type = psa_type + self.compact = compact + self.shrink_factor = shrink_factor + self.mask_size = mask_size + mask_h, mask_w = mask_size + self.psa_softmax = psa_softmax + if normalization_factor is None: + normalization_factor = mask_h * mask_w + self.normalization_factor = normalization_factor + + self.reduce = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + if psa_type == 'bi-direction': + self.reduce_p = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention_p = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + self.psamask_collect = PSAMask('collect', mask_size) + self.psamask_distribute = PSAMask('distribute', mask_size) + else: + self.psamask = PSAMask(psa_type, mask_size) + self.proj = ConvModule( + self.channels * (2 if psa_type == 'bi-direction' else 1), + self.in_channels, + kernel_size=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + self.in_channels * 2, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + identity = x + align_corners = self.align_corners + if self.psa_type in ['collect', 'distribute']: + out = self.reduce(x) + n, c, h, w = out.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + out = resize( + out, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + y = self.attention(out) + if self.compact: + if self.psa_type == 'collect': + y = y.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y = self.psamask(y) + if self.psa_softmax: + y = F.softmax(y, dim=1) + out = torch.bmm( + out.view(n, c, h * w), y.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + else: + x_col = self.reduce(x) + x_dis = self.reduce_p(x) + n, c, h, w = x_col.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + x_col = resize( + x_col, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + x_dis = resize( + x_dis, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + y_col = self.attention(x_col) + y_dis = self.attention_p(x_dis) + if self.compact: + y_dis = y_dis.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y_col = self.psamask_collect(y_col) + y_dis = self.psamask_distribute(y_dis) + if self.psa_softmax: + y_col = F.softmax(y_col, dim=1) + y_dis = F.softmax(y_dis, dim=1) + x_col = torch.bmm( + x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + x_dis = torch.bmm( + x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + out = torch.cat([x_col, x_dis], 1) + out = self.proj(out) + out = resize( + out, + size=identity.shape[2:], + mode='bilinear', + align_corners=align_corners) + out = self.bottleneck(torch.cat((identity, out), dim=1)) + out = self.cls_seg(out) + return out diff --git a/annotator/uniformer/mmseg/models/decode_heads/psp_head.py b/annotator/uniformer/mmseg/models/decode_heads/psp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f1e71c70c3a20f4007c263ec471a87bb214a48 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/psp_head.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class PPM(nn.ModuleList): + """Pooling Pyramid Module used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + align_corners (bool): align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, + act_cfg, align_corners): + super(PPM, self).__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for pool_scale in pool_scales: + self.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(pool_scale), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg))) + + def forward(self, x): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + upsampled_ppm_out = resize( + ppm_out, + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +@HEADS.register_module() +class PSPHead(BaseDecodeHead): + """Pyramid Scene Parsing Network. + + This head is the implementation of + `PSPNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super(PSPHead, self).__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.psp_modules = PPM( + self.pool_scales, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3339a7ac56e77dfc638e9bffb557d4699148686b --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from .aspp_head import ASPPHead, ASPPModule + + +class DepthwiseSeparableASPPModule(ASPPModule): + """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable + conv.""" + + def __init__(self, **kwargs): + super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) + for i, dilation in enumerate(self.dilations): + if dilation > 1: + self[i] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + 3, + dilation=dilation, + padding=dilation, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + +@HEADS.register_module() +class DepthwiseSeparableASPPHead(ASPPHead): + """Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation. + + This head is the implementation of `DeepLabV3+ + `_. + + Args: + c1_in_channels (int): The input channels of c1 decoder. If is 0, + the no decoder will be used. + c1_channels (int): The intermediate channels of c1 decoder. + """ + + def __init__(self, c1_in_channels, c1_channels, **kwargs): + super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) + assert c1_in_channels >= 0 + self.aspp_modules = DepthwiseSeparableASPPModule( + dilations=self.dilations, + in_channels=self.in_channels, + channels=self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if c1_in_channels > 0: + self.c1_bottleneck = ConvModule( + c1_in_channels, + c1_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + else: + self.c1_bottleneck = None + self.sep_bottleneck = nn.Sequential( + DepthwiseSeparableConvModule( + self.channels + c1_channels, + self.channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + DepthwiseSeparableConvModule( + self.channels, + self.channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + output = self.bottleneck(aspp_outs) + if self.c1_bottleneck is not None: + c1_output = self.c1_bottleneck(inputs[0]) + output = resize( + input=output, + size=c1_output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = torch.cat([output, c1_output], dim=1) + output = self.sep_bottleneck(output) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py b/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a0986143fa4f2bd36f5271354fe5f843f35b9e6f --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py @@ -0,0 +1,51 @@ +from annotator.uniformer.mmcv.cnn import DepthwiseSeparableConvModule + +from ..builder import HEADS +from .fcn_head import FCNHead + + +@HEADS.register_module() +class DepthwiseSeparableFCNHead(FCNHead): + """Depthwise-Separable Fully Convolutional Network for Semantic + Segmentation. + + This head is implemented according to Fast-SCNN paper. + Args: + in_channels(int): Number of output channels of FFM. + channels(int): Number of middle-stage channels in the decode head. + concat_input(bool): Whether to concatenate original decode input into + the result of several consecutive convolution layers. + Default: True. + num_classes(int): Used to determine the dimension of + final prediction tensor. + in_index(int): Correspond with 'out_indices' in FastSCNN backbone. + norm_cfg (dict | None): Config of norm layers. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + loss_decode(dict): Config of loss type and some + relevant additional options. + """ + + def __init__(self, **kwargs): + super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) + self.convs[0] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg) + for i in range(1, self.num_convs): + self.convs[i] = DepthwiseSeparableConvModule( + self.channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg) + + if self.concat_input: + self.conv_cat = DepthwiseSeparableConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg) diff --git a/annotator/uniformer/mmseg/models/decode_heads/uper_head.py b/annotator/uniformer/mmseg/models/decode_heads/uper_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1301b706b0d83ed714bbdee8ee24693f150455 --- /dev/null +++ b/annotator/uniformer/mmseg/models/decode_heads/uper_head.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule + +from annotator.uniformer.mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead +from .psp_head import PPM + + +@HEADS.register_module() +class UPerHead(BaseDecodeHead): + """Unified Perceptual Parsing for Scene Understanding. + + This head is the implementation of `UPerNet + `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module applied on the last feature. Default: (1, 2, 3, 6). + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super(UPerHead, self).__init__( + input_transform='multiple_select', **kwargs) + # PSP Module + self.psp_modules = PPM( + pool_scales, + self.in_channels[-1], + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels[-1] + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + fpn_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def psp_forward(self, inputs): + """Forward function of PSP module.""" + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def forward(self, inputs): + """Forward function.""" + + inputs = self._transform_inputs(inputs) + + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + laterals.append(self.psp_forward(inputs)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += resize( + laterals[i], + size=prev_shape, + mode='bilinear', + align_corners=self.align_corners) + + # build outputs + fpn_outs = [ + self.fpn_convs[i](laterals[i]) + for i in range(used_backbone_levels - 1) + ] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = resize( + fpn_outs[i], + size=fpn_outs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fpn_outs = torch.cat(fpn_outs, dim=1) + output = self.fpn_bottleneck(fpn_outs) + output = self.cls_seg(output) + return output diff --git a/annotator/uniformer/mmseg/models/losses/__init__.py b/annotator/uniformer/mmseg/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..beca72045694273d63465bac2f27dbc6672271db --- /dev/null +++ b/annotator/uniformer/mmseg/models/losses/__init__.py @@ -0,0 +1,12 @@ +from .accuracy import Accuracy, accuracy +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy, mask_cross_entropy) +from .dice_loss import DiceLoss +from .lovasz_loss import LovaszLoss +from .utils import reduce_loss, weight_reduce_loss, weighted_loss + +__all__ = [ + 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', + 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', + 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss' +] diff --git a/annotator/uniformer/mmseg/models/losses/accuracy.py b/annotator/uniformer/mmseg/models/losses/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..c0fd2e7e74a0f721c4a814c09d6e453e5956bb38 --- /dev/null +++ b/annotator/uniformer/mmseg/models/losses/accuracy.py @@ -0,0 +1,78 @@ +import torch.nn as nn + + +def accuracy(pred, target, topk=1, thresh=None): + """Calculate accuracy according to the prediction and target. + + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class, ...) + target (torch.Tensor): The target of each prediction, shape (N, , ...) + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == target.ndim + 1 + assert pred.size(0) == target.size(0) + assert maxk <= pred.size(1), \ + f'maxk {maxk} exceeds pred dimension {pred.size(1)}' + pred_value, pred_label = pred.topk(maxk, dim=1) + # transpose to shape (maxk, N, ...) + pred_label = pred_label.transpose(0, 1) + correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / target.numel())) + return res[0] if return_single else res + + +class Accuracy(nn.Module): + """Accuracy calculation module.""" + + def __init__(self, topk=(1, ), thresh=None): + """Module to calculate the accuracy. + + Args: + topk (tuple, optional): The criterion used to calculate the + accuracy. Defaults to (1,). + thresh (float, optional): If not None, predictions with scores + under this threshold are considered incorrect. Default to None. + """ + super().__init__() + self.topk = topk + self.thresh = thresh + + def forward(self, pred, target): + """Forward function to calculate accuracy. + + Args: + pred (torch.Tensor): Prediction of models. + target (torch.Tensor): Target for each prediction. + + Returns: + tuple[float]: The accuracies under different topk criterions. + """ + return accuracy(pred, target, self.topk, self.thresh) diff --git a/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py b/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..42c0790c98616bb69621deed55547fc04c7392ef --- /dev/null +++ b/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py @@ -0,0 +1,198 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=-100): + """The wrapper function for :func:`F.cross_entropy`""" + # class_weight is a manual rescaling weight given to each class. + # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy( + pred, + label, + weight=class_weight, + reduction='none', + ignore_index=ignore_index) + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights *= valid_mask + + return bin_labels, bin_label_weights + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=255): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int | None): The label index to be ignored. Default: 255 + + Returns: + torch.Tensor: The calculated loss + """ + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or ( + pred.dim() == 4 and label.dim() == 3), \ + 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ + 'H, W], label shape [N, H, W] are supported' + label, weight = _expand_onehot_labels(label, weight, pred.shape, + ignore_index) + + # weighted element-wise losses + if weight is not None: + weight = weight.float() + loss = F.binary_cross_entropy_with_logits( + pred, label.float(), pos_weight=class_weight, reduction='none') + # do the reduction for the weighted loss + loss = weight_reduce_loss( + loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy(pred, + target, + label, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=None): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, 'BCE loss does not support ignore_index' + # TODO: handle these two reserved arguments + assert reduction == 'mean' and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits( + pred_slice, target, weight=class_weight, reduction='mean')[None] + + +@LOSSES.register_module() +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + """ + + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + loss_weight=1.0): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/annotator/uniformer/mmseg/models/losses/dice_loss.py b/annotator/uniformer/mmseg/models/losses/dice_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..27a77b962d7d8b3079c7d6cd9db52280c6fb4970 --- /dev/null +++ b/annotator/uniformer/mmseg/models/losses/dice_loss.py @@ -0,0 +1,119 @@ +"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ +segmentron/solver/loss.py (Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weighted_loss + + +@weighted_loss +def dice_loss(pred, + target, + valid_mask, + smooth=1, + exponent=2, + class_weight=None, + ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + dice_loss = binary_dice_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + smooth=smooth, + exponent=exponent) + if class_weight is not None: + dice_loss *= class_weight[i] + total_loss += dice_loss + return total_loss / num_classes + + +@weighted_loss +def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth + den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth + + return 1 - num / den + + +@LOSSES.register_module() +class DiceLoss(nn.Module): + """DiceLoss. + + This loss is proposed in `V-Net: Fully Convolutional Neural Networks for + Volumetric Medical Image Segmentation `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1 + exponent (float): An float number to calculate denominator + value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + """ + + def __init__(self, + smooth=1, + exponent=2, + reduction='mean', + class_weight=None, + loss_weight=1.0, + ignore_index=255, + **kwards): + super(DiceLoss, self).__init__() + self.smooth = smooth + self.exponent = exponent + self.reduction = reduction + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + + def forward(self, + pred, + target, + avg_factor=None, + reduction_override=None, + **kwards): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), + num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * dice_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor, + smooth=self.smooth, + exponent=self.exponent, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss diff --git a/annotator/uniformer/mmseg/models/losses/lovasz_loss.py b/annotator/uniformer/mmseg/models/losses/lovasz_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6badb67f6d987b59fb07aa97caaaf89896e27a8d --- /dev/null +++ b/annotator/uniformer/mmseg/models/losses/lovasz_loss.py @@ -0,0 +1,303 @@ +"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim +Berman 2018 ESAT-PSI KU Leuven (MIT License)""" + +import annotator.uniformer.mmcv as mmcv +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weight_reduce_loss + + +def lovasz_grad(gt_sorted): + """Computes gradient of the Lovasz extension w.r.t sorted errors. + + See Alg. 1 in paper. + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_binary_logits(logits, labels, ignore_index=None): + """Flattens predictions in the batch (binary case) Remove labels equal to + 'ignore_index'.""" + logits = logits.view(-1) + labels = labels.view(-1) + if ignore_index is None: + return logits, labels + valid = (labels != ignore_index) + vlogits = logits[valid] + vlabels = labels[valid] + return vlogits, vlabels + + +def flatten_probs(probs, labels, ignore_index=None): + """Flattens predictions in the batch.""" + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C + labels = labels.view(-1) + if ignore_index is None: + return probs, labels + valid = (labels != ignore_index) + vprobs = probs[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobs, vlabels + + +def lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [P], logits at each prediction + (between -infty and +infty). + labels (torch.Tensor): [P], binary ground truth labels (0 or 1). + + Returns: + torch.Tensor: The calculated loss. + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * signs) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def lovasz_hinge(logits, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [B, H, W], logits at each pixel + (between -infty and +infty). + labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). + classes (str | list[int], optional): Placeholder, to be consistent with + other loss. Default: None. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): Placeholder, to be consistent + with other loss. Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + if per_image: + loss = [ + lovasz_hinge_flat(*flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + for logit, label in zip(logits, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_hinge_flat( + *flatten_binary_logits(logits, labels, ignore_index)) + return loss + + +def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [P, C], class probabilities at each prediction + (between 0 and 1). + labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + class_weight (list[float], optional): The weight for each class. + Default: None. + + Returns: + torch.Tensor: The calculated loss. + """ + if probs.numel() == 0: + # only void pixels, the gradients should be 0 + return probs * 0. + C = probs.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes == 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probs[:, 0] + else: + class_pred = probs[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) + if class_weight is not None: + loss *= class_weight[c] + losses.append(loss) + return torch.stack(losses).mean() + + +def lovasz_softmax(probs, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [B, C, H, W], class probabilities at each + prediction (between 0 and 1). + labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and + C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + + if per_image: + loss = [ + lovasz_softmax_flat( + *flatten_probs( + prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + classes=classes, + class_weight=class_weight) + for prob, label in zip(probs, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_softmax_flat( + *flatten_probs(probs, labels, ignore_index), + classes=classes, + class_weight=class_weight) + return loss + + +@LOSSES.register_module() +class LovaszLoss(nn.Module): + """LovaszLoss. + + This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate + for the optimization of the intersection-over-union measure in neural + networks `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + """ + + def __init__(self, + loss_type='multi_class', + classes='present', + per_image=False, + reduction='mean', + class_weight=None, + loss_weight=1.0): + super(LovaszLoss, self).__init__() + assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + 'binary' or 'multi_class'." + + if loss_type == 'binary': + self.cls_criterion = lovasz_hinge + else: + self.cls_criterion = lovasz_softmax + assert classes in ('all', 'present') or mmcv.is_list_of(classes, int) + if not per_image: + assert reduction == 'none', "reduction should be 'none' when \ + per_image is False." + + self.classes = classes + self.per_image = per_image + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # if multi-class loss, transform logits to probs + if self.cls_criterion == lovasz_softmax: + cls_score = F.softmax(cls_score, dim=1) + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + self.classes, + self.per_image, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/annotator/uniformer/mmseg/models/losses/utils.py b/annotator/uniformer/mmseg/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85aec9f3045240c3de96a928324ae8f5c3aebe8b --- /dev/null +++ b/annotator/uniformer/mmseg/models/losses/utils.py @@ -0,0 +1,121 @@ +import functools + +import annotator.uniformer.mmcv as mmcv +import numpy as np +import torch.nn.functional as F + + +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = mmcv.load(class_weight) + + return class_weight + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Avarage factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/annotator/uniformer/mmseg/models/necks/__init__.py b/annotator/uniformer/mmseg/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9d3d5b3fe80247642d962edd6fb787537d01d6 --- /dev/null +++ b/annotator/uniformer/mmseg/models/necks/__init__.py @@ -0,0 +1,4 @@ +from .fpn import FPN +from .multilevel_neck import MultiLevelNeck + +__all__ = ['FPN', 'MultiLevelNeck'] diff --git a/annotator/uniformer/mmseg/models/necks/fpn.py b/annotator/uniformer/mmseg/models/necks/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..a53b2a69500f8c2edb835abc3ff0ccc2173d1fb1 --- /dev/null +++ b/annotator/uniformer/mmseg/models/necks/fpn.py @@ -0,0 +1,212 @@ +import torch.nn as nn +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule, xavier_init + +from ..builder import NECKS + + +@NECKS.register_module() +class FPN(nn.Module): + """Feature Pyramid Network. + + This is an implementation of - Feature Pyramid Networks for Object + Detection (https://arxiv.org/abs/1612.03144) + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (str): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(mode='nearest')` + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest')): + super(FPN, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] += F.interpolate(laterals[i], + **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/annotator/uniformer/mmseg/models/necks/multilevel_neck.py b/annotator/uniformer/mmseg/models/necks/multilevel_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..766144d8136326a1fab5906a153a0c0df69b6b60 --- /dev/null +++ b/annotator/uniformer/mmseg/models/necks/multilevel_neck.py @@ -0,0 +1,70 @@ +import torch.nn as nn +import torch.nn.functional as F +from annotator.uniformer.mmcv.cnn import ConvModule + +from ..builder import NECKS + + +@NECKS.register_module() +class MultiLevelNeck(nn.Module): + """MultiLevelNeck. + + A neck structure connect vit backbone and decoder_heads. + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + scales (List[int]): Scale factors for each input feature map. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + scales=[0.5, 1, 2, 4], + norm_cfg=None, + act_cfg=None): + super(MultiLevelNeck, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.scales = scales + self.num_outs = len(scales) + self.lateral_convs = nn.ModuleList() + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.lateral_convs.append( + ConvModule( + in_channel, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + for _ in range(self.num_outs): + self.convs.append( + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + print(inputs[0].shape) + inputs = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + # for len(inputs) not equal to self.num_outs + if len(inputs) == 1: + inputs = [inputs[0] for _ in range(self.num_outs)] + outs = [] + for i in range(self.num_outs): + x_resize = F.interpolate( + inputs[i], scale_factor=self.scales[i], mode='bilinear') + outs.append(self.convs[i](x_resize)) + return tuple(outs) diff --git a/annotator/uniformer/mmseg/models/segmentors/__init__.py b/annotator/uniformer/mmseg/models/segmentors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dca2f09405330743c476e190896bee39c45498ea --- /dev/null +++ b/annotator/uniformer/mmseg/models/segmentors/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseSegmentor +from .cascade_encoder_decoder import CascadeEncoderDecoder +from .encoder_decoder import EncoderDecoder + +__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] diff --git a/annotator/uniformer/mmseg/models/segmentors/base.py b/annotator/uniformer/mmseg/models/segmentors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..172fc63b736c4f13be1cd909433bc260760a1eaa --- /dev/null +++ b/annotator/uniformer/mmseg/models/segmentors/base.py @@ -0,0 +1,273 @@ +import logging +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import annotator.uniformer.mmcv as mmcv +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from annotator.uniformer.mmcv.runner import auto_fp16 + + +class BaseSegmentor(nn.Module): + """Base class for segmentors.""" + + __metaclass__ = ABCMeta + + def __init__(self): + super(BaseSegmentor, self).__init__() + self.fp16_enabled = False + + @property + def with_neck(self): + """bool: whether the segmentor has neck""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_auxiliary_head(self): + """bool: whether the segmentor has auxiliary head""" + return hasattr(self, + 'auxiliary_head') and self.auxiliary_head is not None + + @property + def with_decode_head(self): + """bool: whether the segmentor has decode head""" + return hasattr(self, 'decode_head') and self.decode_head is not None + + @abstractmethod + def extract_feat(self, imgs): + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, img, img_metas): + """Placeholder for encode images with backbone and decode into a + semantic segmentation map of the same size as input.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """Placeholder for Forward function for training.""" + pass + + @abstractmethod + def simple_test(self, img, img_meta, **kwargs): + """Placeholder for single image test.""" + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Placeholder for augmentation test.""" + pass + + def init_weights(self, pretrained=None): + """Initialize the weights in segmentor. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if pretrained is not None: + logger = logging.getLogger() + logger.info(f'load model from: {pretrained}') + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError(f'{name} must be a list, but got ' + f'{type(var)}') + + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f'num of augmentations ({len(imgs)}) != ' + f'num of image meta ({len(img_metas)})') + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_['ori_shape'] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_['img_shape'] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_['pad_shape'] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=('img', )) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(data_batch['img_metas'])) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def show_result(self, + img, + result, + palette=None, + win_name='', + show=False, + wait_time=0, + out_file=None, + opacity=0.5): + """Draw `result` over `img`. + + Args: + img (str or Tensor): The image to be displayed. + result (Tensor): The semantic segmentation results to draw over + `img`. + palette (list[list[int]]] | np.ndarray | None): The palette of + segmentation map. If None is given, random palette will be + generated. Default: None + win_name (str): The window name. + wait_time (int): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The filename to write the image. + Default: None. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + Returns: + img (Tensor): Only if not `show` or `out_file` + """ + img = mmcv.imread(img) + img = img.copy() + seg = result[0] + if palette is None: + if self.PALETTE is None: + palette = np.random.randint( + 0, 255, size=(len(self.CLASSES), 3)) + else: + palette = self.PALETTE + palette = np.array(palette) + assert palette.shape[0] == len(self.CLASSES) + assert palette.shape[1] == 3 + assert len(palette.shape) == 2 + assert 0 < opacity <= 1.0 + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + # convert to BGR + color_seg = color_seg[..., ::-1] + + img = img * (1 - opacity) + color_seg * opacity + img = img.astype(np.uint8) + # if out_file specified, do not show image in window + if out_file is not None: + show = False + + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, only ' + 'result image will be returned') + return img diff --git a/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py b/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..873957d8d6468147c994493d92ff5c1b15bfb703 --- /dev/null +++ b/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -0,0 +1,98 @@ +from torch import nn + +from annotator.uniformer.mmseg.core import add_prefix +from annotator.uniformer.mmseg.ops import resize +from .. import builder +from ..builder import SEGMENTORS +from .encoder_decoder import EncoderDecoder + + +@SEGMENTORS.register_module() +class CascadeEncoderDecoder(EncoderDecoder): + """Cascade Encoder Decoder segmentors. + + CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of + CascadeEncoderDecoder are cascaded. The output of previous decoder_head + will be the input of next decoder_head. + """ + + def __init__(self, + num_stages, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + self.num_stages = num_stages + super(CascadeEncoderDecoder, self).__init__( + backbone=backbone, + decode_head=decode_head, + neck=neck, + auxiliary_head=auxiliary_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + pretrained=pretrained) + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + assert isinstance(decode_head, list) + assert len(decode_head) == self.num_stages + self.decode_head = nn.ModuleList() + for i in range(self.num_stages): + self.decode_head.append(builder.build_head(decode_head[i])) + self.align_corners = self.decode_head[-1].align_corners + self.num_classes = self.decode_head[-1].num_classes + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone and heads. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + self.backbone.init_weights(pretrained=pretrained) + for i in range(self.num_stages): + self.decode_head[i].init_weights() + if self.with_auxiliary_head: + if isinstance(self.auxiliary_head, nn.ModuleList): + for aux_head in self.auxiliary_head: + aux_head.init_weights() + else: + self.auxiliary_head.init_weights() + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) + for i in range(1, self.num_stages): + out = self.decode_head[i].forward_test(x, out, img_metas, + self.test_cfg) + out = resize( + input=out, + size=img.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + + loss_decode = self.decode_head[0].forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode_0')) + + for i in range(1, self.num_stages): + # forward test again, maybe unnecessary for most methods. + prev_outputs = self.decode_head[i - 1].forward_test( + x, img_metas, self.test_cfg) + loss_decode = self.decode_head[i].forward_train( + x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_decode, f'decode_{i}')) + + return losses diff --git a/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py b/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..98392ac04c4c44a7f4e7b1c0808266875877dd1f --- /dev/null +++ b/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py @@ -0,0 +1,298 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from annotator.uniformer.mmseg.core import add_prefix +from annotator.uniformer.mmseg.ops import resize +from .. import builder +from ..builder import SEGMENTORS +from .base import BaseSegmentor + + +@SEGMENTORS.register_module() +class EncoderDecoder(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def __init__(self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(EncoderDecoder, self).__init__() + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self.init_weights(pretrained=pretrained) + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone and heads. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + super(EncoderDecoder, self).init_weights(pretrained) + self.backbone.init_weights(pretrained=pretrained) + self.decode_head.init_weights() + if self.with_auxiliary_head: + if isinstance(self.auxiliary_head, nn.ModuleList): + for aux_head in self.auxiliary_head: + aux_head.init_weights() + else: + self.auxiliary_head.init_weights() + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = resize( + input=out, + size=img.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, + gt_semantic_seg, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, + gt_semantic_seg, + self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, + gt_semantic_seg) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train( + x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + # TODO refactor + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy( + count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + if rescale: + preds = resize( + preds, + size=img_meta[0]['ori_shape'][:2], + mode='bilinear', + align_corners=self.align_corners, + warning=False) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]['ori_shape'][:2] + seg_logit = resize( + seg_logit, + size=size, + mode='bilinear', + align_corners=self.align_corners, + warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = img_meta[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in img_meta) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]['flip'] + if flip: + flip_direction = img_meta[0]['flip_direction'] + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + output = output.flip(dims=(3, )) + elif flip_direction == 'vertical': + output = output.flip(dims=(2, )) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/annotator/uniformer/mmseg/models/utils/__init__.py b/annotator/uniformer/mmseg/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3bdd349b9f2ae499a2fcb2ac1d2e3c77befebe --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/__init__.py @@ -0,0 +1,13 @@ +from .drop import DropPath +from .inverted_residual import InvertedResidual, InvertedResidualV3 +from .make_divisible import make_divisible +from .res_layer import ResLayer +from .se_layer import SELayer +from .self_attention_block import SelfAttentionBlock +from .up_conv_block import UpConvBlock +from .weight_init import trunc_normal_ + +__all__ = [ + 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', + 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_' +] diff --git a/annotator/uniformer/mmseg/models/utils/drop.py b/annotator/uniformer/mmseg/models/utils/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..4520b0ff407d2a95a864086bdbca0065f222aa63 --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/drop.py @@ -0,0 +1,31 @@ +"""Modified from https://github.com/rwightman/pytorch-image- +models/blob/master/timm/models/layers/drop.py.""" + +import torch +from torch import nn + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + + Args: + drop_prob (float): Drop rate for paths of model. Dropout rate has + to be between 0 and 1. Default: 0. + """ + + def __init__(self, drop_prob=0.): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.keep_prob = 1 - drop_prob + + def forward(self, x): + if self.drop_prob == 0. or not self.training: + return x + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = self.keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(self.keep_prob) * random_tensor + return output diff --git a/annotator/uniformer/mmseg/models/utils/inverted_residual.py b/annotator/uniformer/mmseg/models/utils/inverted_residual.py new file mode 100644 index 0000000000000000000000000000000000000000..53b8fcd41f71d814738f1ac3f5acd3c3d701bf96 --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/inverted_residual.py @@ -0,0 +1,208 @@ +from annotator.uniformer.mmcv.cnn import ConvModule +from torch import nn +from torch.utils import checkpoint as cp + +from .se_layer import SELayer + + +class InvertedResidual(nn.Module): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InvertedResidualV3(nn.Module): + """Inverted Residual Block for MobileNetV3. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super(InvertedResidualV3, self).__init__() + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=dict( + type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/annotator/uniformer/mmseg/models/utils/make_divisible.py b/annotator/uniformer/mmseg/models/utils/make_divisible.py new file mode 100644 index 0000000000000000000000000000000000000000..75ad756052529f52fe83bb95dd1f0ecfc9a13078 --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/make_divisible.py @@ -0,0 +1,27 @@ +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number to the nearest value that can be + divisible by the divisor. It is taken from the original tf repo. It ensures + that all layers have a channel number that is divisible by divisor. It can + be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel number to + the original channel number. Default: 0.9. + + Returns: + int: The modified output channel number. + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/annotator/uniformer/mmseg/models/utils/res_layer.py b/annotator/uniformer/mmseg/models/utils/res_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c07b47007e92e4c3945b989e79f9d50306f5fe --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/res_layer.py @@ -0,0 +1,94 @@ +from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer +from torch import nn as nn + + +class ResLayer(nn.Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + multi_grid (int | None): Multi grid dilation rates of last + stage. Default: None + contract_dilation (bool): Whether contract first dilation of each layer + Default: False + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + dilation=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + multi_grid=None, + contract_dilation=False, + **kwargs): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if multi_grid is None: + if dilation > 1 and contract_dilation: + first_dilation = dilation // 2 + else: + first_dilation = dilation + else: + first_dilation = multi_grid[0] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + dilation=first_dilation, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + dilation=dilation if multi_grid is None else multi_grid[i], + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + super(ResLayer, self).__init__(*layers) diff --git a/annotator/uniformer/mmseg/models/utils/se_layer.py b/annotator/uniformer/mmseg/models/utils/se_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..083bd7d1ccee909c900c7aed2cc928bf14727f3e --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/se_layer.py @@ -0,0 +1,57 @@ +import annotator.uniformer.mmcv as mmcv +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule + +from .make_divisible import make_divisible + + +class SELayer(nn.Module): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will be + ``int(channels/ratio)``. Default: 16. + conv_cfg (None or dict): Config dict for convolution layer. + Default: None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configured + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configured by the first dict and the + second activation layer will be configured by the second dict. + Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, + divisor=6.0)). + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))): + super(SELayer, self).__init__() + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert mmcv.is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=make_divisible(channels // ratio, 8), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=make_divisible(channels // ratio, 8), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out diff --git a/annotator/uniformer/mmseg/models/utils/self_attention_block.py b/annotator/uniformer/mmseg/models/utils/self_attention_block.py new file mode 100644 index 0000000000000000000000000000000000000000..440c7b73ee4706fde555595926d63a18d7574acc --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/self_attention_block.py @@ -0,0 +1,159 @@ +import torch +from annotator.uniformer.mmcv.cnn import ConvModule, constant_init +from torch import nn as nn +from torch.nn import functional as F + + +class SelfAttentionBlock(nn.Module): + """General self-attention block/non-local block. + + Please refer to https://arxiv.org/abs/1706.03762 for details about key, + query and value. + + Args: + key_in_channels (int): Input channels of key feature. + query_in_channels (int): Input channels of query feature. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + share_key_query (bool): Whether share projection weight between key + and query projection. + query_downsample (nn.Module): Query downsample module. + key_downsample (nn.Module): Key downsample module. + key_query_num_convs (int): Number of convs for key/query projection. + value_num_convs (int): Number of convs for value projection. + matmul_norm (bool): Whether normalize attention map with sqrt of + channels + with_out (bool): Whether use out projection. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, key_in_channels, query_in_channels, channels, + out_channels, share_key_query, query_downsample, + key_downsample, key_query_num_convs, value_out_num_convs, + key_query_norm, value_out_norm, matmul_norm, with_out, + conv_cfg, norm_cfg, act_cfg): + super(SelfAttentionBlock, self).__init__() + if share_key_query: + assert key_in_channels == query_in_channels + self.key_in_channels = key_in_channels + self.query_in_channels = query_in_channels + self.out_channels = out_channels + self.channels = channels + self.share_key_query = share_key_query + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.key_project = self.build_project( + key_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if share_key_query: + self.query_project = self.key_project + else: + self.query_project = self.build_project( + query_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.value_project = self.build_project( + key_in_channels, + channels if with_out else out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if with_out: + self.out_project = self.build_project( + channels, + out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.out_project = None + + self.query_downsample = query_downsample + self.key_downsample = key_downsample + self.matmul_norm = matmul_norm + + self.init_weights() + + def init_weights(self): + """Initialize weight of later layer.""" + if self.out_project is not None: + if not isinstance(self.out_project, ConvModule): + constant_init(self.out_project, 0) + + def build_project(self, in_channels, channels, num_convs, use_conv_module, + conv_cfg, norm_cfg, act_cfg): + """Build projection layer for key/query/value/out.""" + if use_conv_module: + convs = [ + ConvModule( + in_channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ] + for _ in range(num_convs - 1): + convs.append( + ConvModule( + channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + convs = [nn.Conv2d(in_channels, channels, 1)] + for _ in range(num_convs - 1): + convs.append(nn.Conv2d(channels, channels, 1)) + if len(convs) > 1: + convs = nn.Sequential(*convs) + else: + convs = convs[0] + return convs + + def forward(self, query_feats, key_feats): + """Forward function.""" + batch_size = query_feats.size(0) + query = self.query_project(query_feats) + if self.query_downsample is not None: + query = self.query_downsample(query) + query = query.reshape(*query.shape[:2], -1) + query = query.permute(0, 2, 1).contiguous() + + key = self.key_project(key_feats) + value = self.value_project(key_feats) + if self.key_downsample is not None: + key = self.key_downsample(key) + value = self.key_downsample(value) + key = key.reshape(*key.shape[:2], -1) + value = value.reshape(*value.shape[:2], -1) + value = value.permute(0, 2, 1).contiguous() + + sim_map = torch.matmul(query, key) + if self.matmul_norm: + sim_map = (self.channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.reshape(batch_size, -1, *query_feats.shape[2:]) + if self.out_project is not None: + context = self.out_project(context) + return context diff --git a/annotator/uniformer/mmseg/models/utils/up_conv_block.py b/annotator/uniformer/mmseg/models/utils/up_conv_block.py new file mode 100644 index 0000000000000000000000000000000000000000..378469da76cb7bff6a639e7877b3c275d50490fb --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/up_conv_block.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from annotator.uniformer.mmcv.cnn import ConvModule, build_upsample_layer + + +class UpConvBlock(nn.Module): + """Upsample convolution block in decoder for UNet. + + This upsample convolution block consists of one upsample module + followed by one convolution block. The upsample module expands the + high-level low-resolution feature map and the convolution block fuses + the upsampled high-level low-resolution feature map and the low-level + high-resolution feature map from encoder. + + Args: + conv_block (nn.Sequential): Sequential of convolutional layers. + in_channels (int): Number of input channels of the high-level + skip_channels (int): Number of input channels of the low-level + high-resolution feature map from encoder. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers in the conv_block. + Default: 2. + stride (int): Stride of convolutional layer in conv_block. Default: 1. + dilation (int): Dilation rate of convolutional layer in conv_block. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). If the size of + high-level feature map is the same as that of skip feature map + (low-level feature map from encoder), it does not need upsample the + high-level feature map and the upsample_cfg is None. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + conv_block, + in_channels, + skip_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + dcn=None, + plugins=None): + super(UpConvBlock, self).__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.conv_block = conv_block( + in_channels=2 * skip_channels, + out_channels=out_channels, + num_convs=num_convs, + stride=stride, + dilation=dilation, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None) + if upsample_cfg is not None: + self.upsample = build_upsample_layer( + cfg=upsample_cfg, + in_channels=in_channels, + out_channels=skip_channels, + with_cp=with_cp, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.upsample = ConvModule( + in_channels, + skip_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, skip, x): + """Forward function.""" + + x = self.upsample(x) + out = torch.cat([skip, x], dim=1) + out = self.conv_block(out) + + return out diff --git a/annotator/uniformer/mmseg/models/utils/weight_init.py b/annotator/uniformer/mmseg/models/utils/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..38141ba3d61f64ddfc0a31574b4648cbad96d7dd --- /dev/null +++ b/annotator/uniformer/mmseg/models/utils/weight_init.py @@ -0,0 +1,62 @@ +"""Modified from https://github.com/rwightman/pytorch-image- +models/blob/master/timm/models/layers/drop.py.""" + +import math +import warnings + +import torch + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + """Reference: https://people.sc.fsu.edu/~jburkardt/presentations + /truncated_normal.pdf""" + + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower_bound = norm_cdf((a - mean) / std) + upper_bound = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor (``torch.Tensor``): an n-dimensional `torch.Tensor` + mean (float): the mean of the normal distribution + std (float): the standard deviation of the normal distribution + a (float): the minimum cutoff value + b (float): the maximum cutoff value + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/annotator/uniformer/mmseg/ops/__init__.py b/annotator/uniformer/mmseg/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bec51c75b9363a9a19e9fb5c35f4e7dbd6f7751c --- /dev/null +++ b/annotator/uniformer/mmseg/ops/__init__.py @@ -0,0 +1,4 @@ +from .encoding import Encoding +from .wrappers import Upsample, resize + +__all__ = ['Upsample', 'resize', 'Encoding'] diff --git a/annotator/uniformer/mmseg/ops/encoding.py b/annotator/uniformer/mmseg/ops/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb3629a6426550b8e4c537ee1ff4341893e489e --- /dev/null +++ b/annotator/uniformer/mmseg/ops/encoding.py @@ -0,0 +1,74 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class Encoding(nn.Module): + """Encoding Layer: a learnable residual encoder. + + Input is of shape (batch_size, channels, height, width). + Output is of shape (batch_size, num_codes, channels). + + Args: + channels: dimension of the features or feature channels + num_codes: number of code words + """ + + def __init__(self, channels, num_codes): + super(Encoding, self).__init__() + # init codewords and smoothing factor + self.channels, self.num_codes = channels, num_codes + std = 1. / ((num_codes * channels)**0.5) + # [num_codes, channels] + self.codewords = nn.Parameter( + torch.empty(num_codes, channels, + dtype=torch.float).uniform_(-std, std), + requires_grad=True) + # [num_codes] + self.scale = nn.Parameter( + torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), + requires_grad=True) + + @staticmethod + def scaled_l2(x, codewords, scale): + num_codes, channels = codewords.size() + batch_size = x.size(0) + reshaped_scale = scale.view((1, 1, num_codes)) + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + + scaled_l2_norm = reshaped_scale * ( + expanded_x - reshaped_codewords).pow(2).sum(dim=3) + return scaled_l2_norm + + @staticmethod + def aggregate(assignment_weights, x, codewords): + num_codes, channels = codewords.size() + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + batch_size = x.size(0) + + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + encoded_feat = (assignment_weights.unsqueeze(3) * + (expanded_x - reshaped_codewords)).sum(dim=1) + return encoded_feat + + def forward(self, x): + assert x.dim() == 4 and x.size(1) == self.channels + # [batch_size, channels, height, width] + batch_size = x.size(0) + # [batch_size, height x width, channels] + x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() + # assignment_weights: [batch_size, channels, num_codes] + assignment_weights = F.softmax( + self.scaled_l2(x, self.codewords, self.scale), dim=2) + # aggregate + encoded_feat = self.aggregate(assignment_weights, x, self.codewords) + return encoded_feat + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ + f'x{self.channels})' + return repr_str diff --git a/annotator/uniformer/mmseg/ops/wrappers.py b/annotator/uniformer/mmseg/ops/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed9a0cb8d7c0e0ec2748dd89c652756653cac78 --- /dev/null +++ b/annotator/uniformer/mmseg/ops/wrappers.py @@ -0,0 +1,50 @@ +import warnings + +import torch.nn as nn +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super(Upsample, self).__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/annotator/uniformer/mmseg/utils/__init__.py b/annotator/uniformer/mmseg/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac489e2dbbc0e6fa87f5088b4edcc20f8cadc1a6 --- /dev/null +++ b/annotator/uniformer/mmseg/utils/__init__.py @@ -0,0 +1,4 @@ +from .collect_env import collect_env +from .logger import get_root_logger + +__all__ = ['get_root_logger', 'collect_env'] diff --git a/annotator/uniformer/mmseg/utils/collect_env.py b/annotator/uniformer/mmseg/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..65c2134ddbee9655161237dd0894d38c768c2624 --- /dev/null +++ b/annotator/uniformer/mmseg/utils/collect_env.py @@ -0,0 +1,17 @@ +from annotator.uniformer.mmcv.utils import collect_env as collect_base_env +from annotator.uniformer.mmcv.utils import get_git_hash + +import annotator.uniformer.mmseg as mmseg + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print('{}: {}'.format(name, val)) diff --git a/annotator/uniformer/mmseg/utils/logger.py b/annotator/uniformer/mmseg/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..4149d9eda3dfef07490352d22ac40c42460315e4 --- /dev/null +++ b/annotator/uniformer/mmseg/utils/logger.py @@ -0,0 +1,27 @@ +import logging + +from annotator.uniformer.mmcv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmseg". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + + logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) + + return logger diff --git a/annotator/util.py b/annotator/util.py new file mode 100644 index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05 --- /dev/null +++ b/annotator/util.py @@ -0,0 +1,38 @@ +import numpy as np +import cv2 +import os + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img diff --git a/cldm/__pycache__/cldm.cpython-38.pyc b/cldm/__pycache__/cldm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9b5ba5876096c4095e1d75a18aafda94cbac922 Binary files /dev/null and b/cldm/__pycache__/cldm.cpython-38.pyc differ diff --git a/cldm/__pycache__/ddim_hacked.cpython-38.pyc b/cldm/__pycache__/ddim_hacked.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8223207531735b0f5358e75db151b1008c637a83 Binary files /dev/null and b/cldm/__pycache__/ddim_hacked.cpython-38.pyc differ diff --git a/cldm/__pycache__/hack.cpython-38.pyc b/cldm/__pycache__/hack.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4ee0f5cd38fe089a75cc0402e0cf3ea026e1eca Binary files /dev/null and b/cldm/__pycache__/hack.cpython-38.pyc differ diff --git a/cldm/__pycache__/model.cpython-38.pyc b/cldm/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e359881618fa8dc1672b808a5944c31e7d10bb85 Binary files /dev/null and b/cldm/__pycache__/model.cpython-38.pyc differ diff --git a/cldm/cldm.py b/cldm/cldm.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3ac7a575cf4933fc14dfc15dd3cca41cb3f3e8 --- /dev/null +++ b/cldm/cldm.py @@ -0,0 +1,435 @@ +import einops +import torch +import torch as th +import torch.nn as nn + +from ldm.modules.diffusionmodules.util import ( + conv_nd, + linear, + zero_module, + timestep_embedding, +) + +from einops import rearrange, repeat +from torchvision.utils import make_grid +from ldm.modules.attention import SpatialTransformer +from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.util import log_txt_as_img, exists, instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler + + +class ControlledUnetModel(UNetModel): + def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs): + hs = [] + with torch.no_grad(): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + + if control is not None: + h += control.pop() + + for i, module in enumerate(self.output_blocks): + if only_mid_control or control is None: + h = torch.cat([h, hs.pop()], dim=1) + else: + h = torch.cat([h, hs.pop() + control.pop()], dim=1) + h = module(h, emb, context) + + h = h.type(x.dtype) + return self.out(h) + + +class ControlNet(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + hint_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.dims = dims + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + + self.input_hint_block = TimestepEmbedSequential( + conv_nd(dims, hint_channels, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 96, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 96, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 256, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) + ) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_convs.append(self.make_zero_conv(ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + self.zero_convs.append(self.make_zero_conv(ch)) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self.middle_block_out = self.make_zero_conv(ch) + self._feature_size += ch + + def make_zero_conv(self, channels): + return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + + def forward(self, x, hint, timesteps, context, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + guided_hint = self.input_hint_block(hint, emb, context) + + outs = [] + + h = x.type(self.dtype) + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context) + outs.append(self.middle_block_out(h, emb, context)) + + return outs + + +class ControlLDM(LatentDiffusion): + + def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs): + super().__init__(*args, **kwargs) + self.control_model = instantiate_from_config(control_stage_config) + self.control_key = control_key + self.only_mid_control = only_mid_control + self.control_scales = [1.0] * 13 + + @torch.no_grad() + def get_input(self, batch, k, bs=None, *args, **kwargs): + x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs) + control = batch[self.control_key] + if bs is not None: + control = control[:bs] + control = control.to(self.device) + control = einops.rearrange(control, 'b h w c -> b c h w') + control = control.to(memory_format=torch.contiguous_format).float() + return x, dict(c_crossattn=[c], c_concat=[control]) + + def apply_model(self, x_noisy, t, cond, *args, **kwargs): + assert isinstance(cond, dict) + diffusion_model = self.model.diffusion_model + + cond_txt = torch.cat(cond['c_crossattn'], 1) + + if cond['c_concat'] is None: + eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control) + else: + control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt) + control = [c * scale for c, scale in zip(control, self.control_scales)] + eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control) + + return eps + + @torch.no_grad() + def get_unconditional_conditioning(self, N): + return self.get_learned_conditioning([""] * N) + + @torch.no_grad() + def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + use_ddim = ddim_steps is not None + + log = dict() + z, c = self.get_input(batch, self.first_stage_key, bs=N) + c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N] + N = min(z.shape[0], N) + n_row = min(z.shape[0], n_row) + log["reconstruction"] = self.decode_first_stage(z) + log["control"] = c_cat * 2.0 - 1.0 + log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N) + uc_cat = c_cat # torch.zeros_like(c_cat) + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + ddim_sampler = DDIMSampler(self) + b, c, h, w = cond["c_concat"][0].shape + shape = (self.channels, h // 8, w // 8) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs) + return samples, intermediates + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.control_model.parameters()) + if not self.sd_locked: + params += list(self.model.diffusion_model.output_blocks.parameters()) + params += list(self.model.diffusion_model.out.parameters()) + opt = torch.optim.AdamW(params, lr=lr) + return opt + + def low_vram_shift(self, is_diffusing): + if is_diffusing: + self.model = self.model.cuda() + self.control_model = self.control_model.cuda() + self.first_stage_model = self.first_stage_model.cpu() + self.cond_stage_model = self.cond_stage_model.cpu() + else: + self.model = self.model.cpu() + self.control_model = self.control_model.cpu() + self.first_stage_model = self.first_stage_model.cuda() + self.cond_stage_model = self.cond_stage_model.cuda() diff --git a/cldm/ddim_hacked.py b/cldm/ddim_hacked.py new file mode 100644 index 0000000000000000000000000000000000000000..25b1bc947272ad14d7f7e5e4d1809005253b63d0 --- /dev/null +++ b/cldm/ddim_hacked.py @@ -0,0 +1,317 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c) + else: + model_t = self.model.apply_model(x, t, c) + model_uncond = self.model.apply_model(x, t, unconditional_conditioning) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + num_reference_steps = timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec diff --git a/cldm/hack.py b/cldm/hack.py new file mode 100644 index 0000000000000000000000000000000000000000..454361e9d036cd1a6a79122c2fd16b489e4767b1 --- /dev/null +++ b/cldm/hack.py @@ -0,0 +1,111 @@ +import torch +import einops + +import ldm.modules.encoders.modules +import ldm.modules.attention + +from transformers import logging +from ldm.modules.attention import default + + +def disable_verbosity(): + logging.set_verbosity_error() + print('logging improved.') + return + + +def enable_sliced_attention(): + ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward + print('Enabled sliced_attention.') + return + + +def hack_everything(clip_skip=0): + disable_verbosity() + ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward + ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip + print('Enabled clip hacks.') + return + + +# Written by Lvmin +def _hacked_clip_forward(self, text): + PAD = self.tokenizer.pad_token_id + EOS = self.tokenizer.eos_token_id + BOS = self.tokenizer.bos_token_id + + def tokenize(t): + return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] + + def transformer_encode(t): + if self.clip_skip > 1: + rt = self.transformer(input_ids=t, output_hidden_states=True) + return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) + else: + return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state + + def split(x): + return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] + + def pad(x, p, i): + return x[:i] if len(x) >= i else x + [p] * (i - len(x)) + + raw_tokens_list = tokenize(text) + tokens_list = [] + + for raw_tokens in raw_tokens_list: + raw_tokens_123 = split(raw_tokens) + raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] + raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] + tokens_list.append(raw_tokens_123) + + tokens_list = torch.IntTensor(tokens_list).to(self.device) + + feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') + y = transformer_encode(feed) + z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) + + return z + + +# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py +def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + limit = k.shape[0] + att_step = 1 + q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) + k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) + v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) + + q_chunks.reverse() + k_chunks.reverse() + v_chunks.reverse() + sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + del k, q, v + for i in range(0, limit, att_step): + q_buffer = q_chunks.pop() + k_buffer = k_chunks.pop() + v_buffer = v_chunks.pop() + sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale + + del k_buffer, q_buffer + # attention, what we cannot get enough of, by chunks + + sim_buffer = sim_buffer.softmax(dim=-1) + + sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) + del v_buffer + sim[i:i + att_step, :, :] = sim_buffer + + del sim_buffer + sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) + return self.to_out(sim) diff --git a/cldm/logger.py b/cldm/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8803846f2a8979f87f3cf9ea5b12869439e62f --- /dev/null +++ b/cldm/logger.py @@ -0,0 +1,76 @@ +import os + +import numpy as np +import torch +import torchvision +from PIL import Image +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities.distributed import rank_zero_only + + +class ImageLogger(Callback): + def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, + rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only + def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "image_log", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and + callable(pl_module.log_images) and + self.max_images > 0): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) + + self.log_local(pl_module.logger.save_dir, split, images, + pl_module.global_step, pl_module.current_epoch, batch_idx) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + return check_idx % self.batch_freq == 0 + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + if not self.disabled: + self.log_img(pl_module, batch, batch_idx, split="train") diff --git a/cldm/model.py b/cldm/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fed3c31ac145b78907c7f771d1d8db6fb32d92ed --- /dev/null +++ b/cldm/model.py @@ -0,0 +1,28 @@ +import os +import torch + +from omegaconf import OmegaConf +from ldm.util import instantiate_from_config + + +def get_state_dict(d): + return d.get('state_dict', d) + + +def load_state_dict(ckpt_path, location='cpu'): + _, extension = os.path.splitext(ckpt_path) + if extension.lower() == ".safetensors": + import safetensors.torch + state_dict = safetensors.torch.load_file(ckpt_path, device=location) + else: + state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) + state_dict = get_state_dict(state_dict) + print(f'Loaded state_dict from [{ckpt_path}]') + return state_dict + + +def create_model(config_path): + config = OmegaConf.load(config_path) + model = instantiate_from_config(config.model).cpu() + print(f'Loaded model config from [{config_path}]') + return model diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c738d8cbad66bbe1666284aef926c326849701 --- /dev/null +++ b/config.py @@ -0,0 +1 @@ +save_memory = False diff --git a/diffusers/__init__.py b/diffusers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..445e63b4d406f3a32f2ee04556a6613d2646893e --- /dev/null +++ b/diffusers/__init__.py @@ -0,0 +1,289 @@ +__version__ = "0.19.0.dev0" + +from .configuration_utils import ConfigMixin +from .utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_inflect_available, + is_invisible_watermark_available, + is_k_diffusion_available, + is_k_diffusion_version, + is_librosa_available, + is_note_seq_available, + is_onnx_available, + is_scipy_available, + is_torch_available, + is_torchsde_available, + is_transformers_available, + is_transformers_version, + is_unidecode_available, + logging, +) + + +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 +else: + from .pipelines import OnnxRuntimeModel + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_pt_objects import * # noqa F403 +else: + from .models import ( + AutoencoderKL, + ControlNetModel, + ModelMixin, + MultiAdapter, + PriorTransformer, + T2IAdapter, + T5FilmDecoder, + Transformer2DModel, + UNet1DModel, + UNet2DConditionModel, + UNet2DModel, + UNet3DConditionModel, + VQModel, + ) + from .optimization import ( + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, + get_scheduler, + ) + from .pipelines import ( + AudioPipelineOutput, + ConsistencyModelPipeline, + DanceDiffusionPipeline, + DDIMPipeline, + DDPMPipeline, + DiffusionPipeline, + DiTPipeline, + ImagePipelineOutput, + KarrasVePipeline, + LDMPipeline, + LDMSuperResolutionPipeline, + PNDMPipeline, + RePaintPipeline, + ScoreSdeVePipeline, + ) + from .schedulers import ( + CMStochasticIterativeScheduler, + DDIMInverseScheduler, + DDIMParallelScheduler, + DDIMScheduler, + DDPMParallelScheduler, + DDPMScheduler, + DEISMultistepScheduler, + DPMSolverMultistepInverseScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + IPNDMScheduler, + KarrasVeScheduler, + KDPM2AncestralDiscreteScheduler, + KDPM2DiscreteScheduler, + PNDMScheduler, + RePaintScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, + UnCLIPScheduler, + UniPCMultistepScheduler, + VQDiffusionScheduler, + ) + from .training_utils import EMAModel + +try: + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_scipy_objects import * # noqa F403 +else: + from .schedulers import LMSDiscreteScheduler + +try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 +else: + from .schedulers import DPMSolverSDEScheduler + +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipelines import ( + AltDiffusionImg2ImgPipeline, + AltDiffusionPipeline, + AudioLDMPipeline, + CycleDiffusionPipeline, + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, + ImageTextPipelineOutput, + KandinskyImg2ImgPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, + KandinskyV22ControlnetImg2ImgPipeline, + KandinskyV22ControlnetPipeline, + KandinskyV22Img2ImgPipeline, + KandinskyV22InpaintPipeline, + KandinskyV22Pipeline, + KandinskyV22PriorEmb2EmbPipeline, + KandinskyV22PriorPipeline, + LDMTextToImagePipeline, + PaintByExamplePipeline, + SemanticStableDiffusionPipeline, + ShapEImg2ImgPipeline, + ShapEPipeline, + StableDiffusionAdapterPipeline, + StableDiffusionAttendAndExcitePipeline, + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, + StableDiffusionImageVariationPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, + StableDiffusionLDM3DPipeline, + StableDiffusionModelEditingPipeline, + StableDiffusionPanoramaPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPipeline, + StableDiffusionPipelineSafe, + StableDiffusionPix2PixZeroPipeline, + StableDiffusionSAGPipeline, + StableDiffusionUpscalePipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + TextToVideoSDPipeline, + TextToVideoZeroPipeline, + UnCLIPImageVariationPipeline, + UnCLIPPipeline, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + VideoToVideoSDPipeline, + VQDiffusionPipeline, + ) + +try: + if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 +else: + from .pipelines import ( + StableDiffusionXLControlNetPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, + ) + +try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 +else: + from .pipelines import StableDiffusionKDiffusionPipeline + +try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 +else: + from .pipelines import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, + OnnxStableDiffusionPipeline, + OnnxStableDiffusionUpscalePipeline, + StableDiffusionOnnxPipeline, + ) + +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_librosa_objects import * # noqa F403 +else: + from .pipelines import AudioDiffusionPipeline, Mel + +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 +else: + from .pipelines import SpectrogramDiffusionPipeline + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_flax_objects import * # noqa F403 +else: + from .models.controlnet_flax import FlaxControlNetModel + from .models.modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipelines import FlaxDiffusionPipeline + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxSchedulerMixin, + FlaxScoreSdeVeScheduler, + ) + + +try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 +else: + from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + ) + +try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 +else: + from .pipelines import MidiProcessor diff --git a/diffusers/__pycache__/__init__.cpython-310.pyc b/diffusers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17db3d7b74eb1088c064f01e3ad8d93976b5cafb Binary files /dev/null and b/diffusers/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/__pycache__/__init__.cpython-38.pyc b/diffusers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27ee15d5a0c624e28fb039c9866b4c7b8fdba2ac Binary files /dev/null and b/diffusers/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/__pycache__/configuration_utils.cpython-310.pyc b/diffusers/__pycache__/configuration_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05f83a659005e09d5efacdb9f6e30232e5727e1d Binary files /dev/null and b/diffusers/__pycache__/configuration_utils.cpython-310.pyc differ diff --git a/diffusers/__pycache__/configuration_utils.cpython-38.pyc b/diffusers/__pycache__/configuration_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b24cab046e3e49de39f21b88ef15b8a96e09ef8 Binary files /dev/null and b/diffusers/__pycache__/configuration_utils.cpython-38.pyc differ diff --git a/diffusers/__pycache__/image_processor.cpython-310.pyc b/diffusers/__pycache__/image_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4ce9c9c22a7e849d546cdeed1f7c94cb83d41a9 Binary files /dev/null and b/diffusers/__pycache__/image_processor.cpython-310.pyc differ diff --git a/diffusers/__pycache__/image_processor.cpython-38.pyc b/diffusers/__pycache__/image_processor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3598363605f0c1ec2bdba08dd1f14d173517cc2d Binary files /dev/null and b/diffusers/__pycache__/image_processor.cpython-38.pyc differ diff --git a/diffusers/__pycache__/loaders.cpython-310.pyc b/diffusers/__pycache__/loaders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1318e3ce33889b7df3f51d97bd67ab7fc9e29c98 Binary files /dev/null and b/diffusers/__pycache__/loaders.cpython-310.pyc differ diff --git a/diffusers/__pycache__/loaders.cpython-38.pyc b/diffusers/__pycache__/loaders.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15cde8a510b2473cad8a4b01fe63c3a92b6e891d Binary files /dev/null and b/diffusers/__pycache__/loaders.cpython-38.pyc differ diff --git a/diffusers/__pycache__/optimization.cpython-310.pyc b/diffusers/__pycache__/optimization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb1d462d18b8d02a7b61f8ca1ca9141546a0c8bc Binary files /dev/null and b/diffusers/__pycache__/optimization.cpython-310.pyc differ diff --git a/diffusers/__pycache__/optimization.cpython-38.pyc b/diffusers/__pycache__/optimization.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49b58142d3b7fe20008979656b90386f07f27aae Binary files /dev/null and b/diffusers/__pycache__/optimization.cpython-38.pyc differ diff --git a/diffusers/__pycache__/training_utils.cpython-310.pyc b/diffusers/__pycache__/training_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74fe61cfa35f11e3848689a8c493b8d0a4bc750a Binary files /dev/null and b/diffusers/__pycache__/training_utils.cpython-310.pyc differ diff --git a/diffusers/__pycache__/training_utils.cpython-38.pyc b/diffusers/__pycache__/training_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13347a7ba998f3c1dfbfe0b1c97beae04a1b4d7c Binary files /dev/null and b/diffusers/__pycache__/training_utils.cpython-38.pyc differ diff --git a/diffusers/commands/__init__.py b/diffusers/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad4af9199bbe297dbc6679fd9ecb46baa976053 --- /dev/null +++ b/diffusers/commands/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from argparse import ArgumentParser + + +class BaseDiffusersCLICommand(ABC): + @staticmethod + @abstractmethod + def register_subcommand(parser: ArgumentParser): + raise NotImplementedError() + + @abstractmethod + def run(self): + raise NotImplementedError() diff --git a/diffusers/commands/diffusers_cli.py b/diffusers/commands/diffusers_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..2016fc19f557fd539782ca2181ec2fe74026340a --- /dev/null +++ b/diffusers/commands/diffusers_cli.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser + +from .env import EnvironmentCommand +from .fp16_safetensors import FP16SafetensorsCommand + + +def main(): + parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []") + commands_parser = parser.add_subparsers(help="diffusers-cli command helpers") + + # Register commands + EnvironmentCommand.register_subcommand(commands_parser) + FP16SafetensorsCommand.register_subcommand(commands_parser) + + # Let's go + args = parser.parse_args() + + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + # Run + service = args.func(args) + service.run() + + +if __name__ == "__main__": + main() diff --git a/diffusers/commands/env.py b/diffusers/commands/env.py new file mode 100644 index 0000000000000000000000000000000000000000..db9de720942b5efcff921d7e2503e3ae8813561e --- /dev/null +++ b/diffusers/commands/env.py @@ -0,0 +1,84 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +from argparse import ArgumentParser + +import huggingface_hub + +from .. import __version__ as version +from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available +from . import BaseDiffusersCLICommand + + +def info_command_factory(_): + return EnvironmentCommand() + + +class EnvironmentCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + download_parser = parser.add_parser("env") + download_parser.set_defaults(func=info_command_factory) + + def run(self): + hub_version = huggingface_hub.__version__ + + pt_version = "not installed" + pt_cuda_available = "NA" + if is_torch_available(): + import torch + + pt_version = torch.__version__ + pt_cuda_available = torch.cuda.is_available() + + transformers_version = "not installed" + if is_transformers_available(): + import transformers + + transformers_version = transformers.__version__ + + accelerate_version = "not installed" + if is_accelerate_available(): + import accelerate + + accelerate_version = accelerate.__version__ + + xformers_version = "not installed" + if is_xformers_available(): + import xformers + + xformers_version = xformers.__version__ + + info = { + "`diffusers` version": version, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "Huggingface_hub version": hub_version, + "Transformers version": transformers_version, + "Accelerate version": accelerate_version, + "xFormers version": xformers_version, + "Using GPU in script?": "", + "Using distributed or parallel set-up in script?": "", + } + + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") + print(self.format_dict(info)) + + return info + + @staticmethod + def format_dict(d): + return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" diff --git a/diffusers/commands/fp16_safetensors.py b/diffusers/commands/fp16_safetensors.py new file mode 100644 index 0000000000000000000000000000000000000000..19553c752dce116d01f9816f90ddd3275d8cc302 --- /dev/null +++ b/diffusers/commands/fp16_safetensors.py @@ -0,0 +1,138 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage example: + diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors +""" + +import glob +import json +from argparse import ArgumentParser, Namespace +from importlib import import_module + +import huggingface_hub +import torch +from huggingface_hub import hf_hub_download +from packaging import version + +from ..utils import is_safetensors_available, logging +from . import BaseDiffusersCLICommand + + +def conversion_command_factory(args: Namespace): + return FP16SafetensorsCommand( + args.ckpt_id, + args.fp16, + args.use_safetensors, + args.use_auth_token, + ) + + +class FP16SafetensorsCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + conversion_parser = parser.add_parser("fp16_safetensors") + conversion_parser.add_argument( + "--ckpt_id", + type=str, + help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.", + ) + conversion_parser.add_argument( + "--fp16", action="store_true", help="If serializing the variables in FP16 precision." + ) + conversion_parser.add_argument( + "--use_safetensors", action="store_true", help="If serializing in the safetensors format." + ) + conversion_parser.add_argument( + "--use_auth_token", + action="store_true", + help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.", + ) + conversion_parser.set_defaults(func=conversion_command_factory) + + def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool): + self.logger = logging.get_logger("diffusers-cli/fp16_safetensors") + self.ckpt_id = ckpt_id + self.local_ckpt_dir = f"/tmp/{ckpt_id}" + self.fp16 = fp16 + + if is_safetensors_available(): + self.use_safetensors = use_safetensors + else: + raise ImportError( + "When `use_safetensors` is set to True, the `safetensors` library needs to be installed. Install it via `pip install safetensors`." + ) + + if not self.use_safetensors and not self.fp16: + raise NotImplementedError( + "When `use_safetensors` and `fp16` both are False, then this command is of no use." + ) + + self.use_auth_token = use_auth_token + + def run(self): + if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"): + raise ImportError( + "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub" + " installation." + ) + else: + from huggingface_hub import create_commit + from huggingface_hub._commit_api import CommitOperationAdd + + model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token) + with open(model_index, "r") as f: + pipeline_class_name = json.load(f)["_class_name"] + pipeline_class = getattr(import_module("diffusers"), pipeline_class_name) + self.logger.info(f"Pipeline class imported: {pipeline_class_name}.") + + # Load the appropriate pipeline. We could have use `DiffusionPipeline` + # here, but just to avoid any rough edge cases. + pipeline = pipeline_class.from_pretrained( + self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token + ) + pipeline.save_pretrained( + self.local_ckpt_dir, + safe_serialization=True if self.use_safetensors else False, + variant="fp16" if self.fp16 else None, + ) + self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.") + + # Fetch all the paths. + if self.fp16: + modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*") + elif self.use_safetensors: + modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors") + + # Prepare for the PR. + commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}." + operations = [] + for path in modified_paths: + operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path)) + + # Open the PR. + commit_description = ( + "Variables converted by the [`diffusers`' `fp16_safetensors`" + " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)." + ) + hub_pr_url = create_commit( + repo_id=self.ckpt_id, + operations=operations, + commit_message=commit_message, + commit_description=commit_description, + repo_type="model", + create_pr=True, + ).pr_url + self.logger.info(f"PR created here: {hub_pr_url}.") diff --git a/diffusers/configuration_utils.py b/diffusers/configuration_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c8e8919c9fcd48de5a89e0664bd6c00643f515 --- /dev/null +++ b/diffusers/configuration_utils.py @@ -0,0 +1,664 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ConfigMixin base class and utilities.""" +import dataclasses +import functools +import importlib +import inspect +import json +import os +import re +from collections import OrderedDict +from pathlib import PosixPath +from typing import Any, Dict, Tuple, Union + +import numpy as np +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from . import __version__ +from .utils import ( + DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + DummyObject, + deprecate, + extract_commit_hash, + http_user_agent, + logging, +) + + +logger = logging.get_logger(__name__) + +_re_configuration_file = re.compile(r"config\.(.*)\.json") + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + +class ConfigMixin: + r""" + Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also + provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and + saving classes that inherit from [`ConfigMixin`]. + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). + """ + config_name = None + ignore_for_config = [] + has_compatibles = False + + _deprecated_kwargs = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. + kwargs.pop("kwargs", None) + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = FrozenDict(internal_dict) + + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 + + Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite: + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + return self._internal_dict[name] + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file is saved (will be created if it does not exist). + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + self.to_json_file(output_config_file) + logger.info(f"Configuration saved in {output_config_file}") + + @classmethod + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): + r""" + Instantiate a Python class from a config dictionary. + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class is instantiated. Make sure to only load configuration + files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it is loaded) and initiate the Python class. + `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually + overwrite the same named arguments in `config`. + + Returns: + [`ModelMixin`] or [`SchedulerMixin`]: + A model or scheduler object instantiated from a config dictionary. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. + >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError("Please make sure to provide a config as the first positional argument.") + # ======> + + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." + if "Scheduler" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." + " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" + " be removed in v1.0.0." + ) + elif "Model" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a model, please use {cls}.load_config(...) followed by" + f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" + " instead. This functionality will be removed in v1.0.0." + ) + deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) + config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) + + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # add possible deprecated kwargs + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + model.register_to_config(**hidden_dict) + + # add hidden kwargs of compatible classes to unused_kwargs + unused_kwargs = {**unused_kwargs, **hidden_dict} + + if return_unused_kwargs: + return (model, unused_kwargs) + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = ( + f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" + " removed in version v1.0.0" + ) + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Load a model or scheduler configuration. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with + [`~ConfigMixin.save_config`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + return_unused_kwargs (`bool`, *optional*, defaults to `False): + Whether unused keyword arguments of the config are returned. + return_commit_hash (`bool`, *optional*, defaults to `False): + Whether the `commit_hash` of the loaded configuration are returned. + + Returns: + `dict`: + A dictionary of all the parameters stored in a JSON configuration file. + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + subfolder = kwargs.pop("subfolder", None) + user_agent = kwargs.pop("user_agent", {}) + + user_agent = {**user_agent, "file_type": "config"} + user_agent = http_user_agent(user_agent) + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from " + "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + # Load from a PyTorch checkpoint + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + config_file = hf_hub_download( + pretrained_model_name_or_path, + filename=cls.config_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" + " login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." + ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(config_file) + + commit_hash = extract_commit_hash(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + + if not (return_unused_kwargs or return_commit_hash): + return config_dict + + outputs = (config_dict,) + + if return_unused_kwargs: + outputs += (kwargs,) + + if return_commit_hash: + outputs += (commit_hash,) + + return outputs + + @staticmethod + def _get_init_keys(cls): + return set(dict(inspect.signature(cls.__init__).parameters).keys()) + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + # Skip keys that were not present in the original config, so default __init__ values were used + used_defaults = config_dict.get("_use_default_values", []) + config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + + # 0. Copy origin config dict + original_dict = dict(config_dict.items()) + + # 1. Retrieve expected config attributes from __init__ signature + expected_keys = cls._get_init_keys(cls) + expected_keys.remove("self") + # remove general kwargs if present in dict + if "kwargs" in expected_keys: + expected_keys.remove("kwargs") + # remove flax internal keys + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + + # 2. Remove attributes that cannot be expected from expected config attributes + # remove keys to be ignored + if len(cls.ignore_for_config) > 0: + expected_keys = expected_keys - set(cls.ignore_for_config) + + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0]) + + if cls.has_compatibles: + compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] + else: + compatible_classes = [] + + expected_keys_comp_cls = set() + for c in compatible_classes: + expected_keys_c = cls._get_init_keys(c) + expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) + expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) + config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls} + + # remove attributes from orig class that cannot be expected + orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): + orig_cls = getattr(diffusers_library, orig_cls_name) + unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys + config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} + + # remove private attributes + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + + # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments + init_dict = {} + for key in expected_keys: + # if config param is passed to kwarg and is present in config dict + # it should overwrite existing config dict key + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + + if key in kwargs: + # overwrite key + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) + + # 4. Give nice warning if unexpected values have been passed + if len(config_dict) > 0: + logger.warning( + f"The config attributes {config_dict} were passed to {cls.__name__}, " + "but are not expected and will be ignored. Please verify your " + f"{cls.config_name} configuration file." + ) + + # 5. Give nice info if config attributes are initiliazed to default because they have not been passed + passed_keys = set(init_dict.keys()) + if len(expected_keys - passed_keys) > 0: + logger.info( + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." + ) + + # 6. Define unused keyword arguments + unused_kwargs = {**config_dict, **kwargs} + + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + + return init_dict, unused_kwargs, hidden_config_dict + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes the configuration instance to a JSON string. + + Returns: + `str`: + String containing all the attributes that make up the configuration instance in JSON format. + """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + + def to_json_saveable(value): + if isinstance(value, np.ndarray): + value = value.tolist() + elif isinstance(value, PosixPath): + value = str(value) + return value + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + # Don't save "_ignore_files" or "_use_default_values" + config_dict.pop("_ignore_files", None) + config_dict.pop("_use_default_values", None) + + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save the configuration instance's parameters to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file to save a configuration instance's parameters. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) + + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init + + +def flax_register_to_config(cls): + original_init = cls.__init__ + + @functools.wraps(original_init) + def init(self, *args, **kwargs): + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + # Ignore private kwargs in the init. Retrieve all passed attributes + init_kwargs = dict(kwargs.items()) + + # Retrieve default values + fields = dataclasses.fields(self) + default_kwargs = {} + for field in fields: + # ignore flax specific attributes + if field.name in self._flax_internal_args: + continue + if type(field.default) == dataclasses._MISSING_TYPE: + default_kwargs[field.name] = None + else: + default_kwargs[field.name] = getattr(self, field.name) + + # Make sure init_kwargs override default kwargs + new_kwargs = {**default_kwargs, **init_kwargs} + # dtype should be part of `init_kwargs`, but not `new_kwargs` + if "dtype" in new_kwargs: + new_kwargs.pop("dtype") + + # Get positional arguments aligned with kwargs + for i, arg in enumerate(args): + name = fields[i].name + new_kwargs[name] = arg + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) + + getattr(self, "register_to_config")(**new_kwargs) + original_init(self, *args, **kwargs) + + cls.__init__ = init + return cls diff --git a/diffusers/dependency_versions_check.py b/diffusers/dependency_versions_check.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8578c52957bf6c06decb0d97d3139437f0078f --- /dev/null +++ b/diffusers/dependency_versions_check.py @@ -0,0 +1,47 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from .dependency_versions_table import deps +from .utils.versions import require_version, require_version_core + + +# define which module versions we always want to check at run time +# (usually the ones defined in `install_requires` in setup.py) +# +# order specific notes: +# - tqdm must be checked before tokenizers + +pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() +if sys.version_info < (3, 7): + pkgs_to_check_at_runtime.append("dataclasses") +if sys.version_info < (3, 8): + pkgs_to_check_at_runtime.append("importlib_metadata") + +for pkg in pkgs_to_check_at_runtime: + if pkg in deps: + if pkg == "tokenizers": + # must be loaded here, or else tqdm check may fail + from .utils import is_tokenizers_available + + if not is_tokenizers_available(): + continue # not required, check version only if installed + + require_version_core(deps[pkg]) + else: + raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") + + +def dep_version_check(pkg, hint=None): + require_version(deps[pkg], hint) diff --git a/diffusers/dependency_versions_table.py b/diffusers/dependency_versions_table.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6cb9e816ccf9ac5f558aeae4e7ce75a3c6feeb --- /dev/null +++ b/diffusers/dependency_versions_table.py @@ -0,0 +1,44 @@ +# THIS FILE HAS BEEN AUTOGENERATED. To update: +# 1. modify the `_deps` dict in setup.py +# 2. run `make deps_table_update`` +deps = { + "Pillow": "Pillow", + "accelerate": "accelerate>=0.11.0", + "compel": "compel==0.1.8", + "black": "black~=23.1", + "datasets": "datasets", + "filelock": "filelock", + "flax": "flax>=0.4.1", + "hf-doc-builder": "hf-doc-builder>=0.3.0", + "huggingface-hub": "huggingface-hub>=0.13.2", + "requests-mock": "requests-mock==1.10.0", + "importlib_metadata": "importlib_metadata", + "invisible-watermark": "invisible-watermark>=0.2.0", + "isort": "isort>=5.5.4", + "jax": "jax>=0.2.8,!=0.3.2", + "jaxlib": "jaxlib>=0.1.65", + "Jinja2": "Jinja2", + "k-diffusion": "k-diffusion>=0.0.12", + "torchsde": "torchsde", + "note_seq": "note_seq", + "librosa": "librosa", + "numpy": "numpy", + "omegaconf": "omegaconf", + "parameterized": "parameterized", + "protobuf": "protobuf>=3.20.3,<4", + "pytest": "pytest", + "pytest-timeout": "pytest-timeout", + "pytest-xdist": "pytest-xdist", + "ruff": "ruff>=0.0.241", + "safetensors": "safetensors", + "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", + "scipy": "scipy", + "onnx": "onnx", + "regex": "regex!=2019.12.17", + "requests": "requests", + "tensorboard": "tensorboard", + "torch": "torch>=1.4", + "torchvision": "torchvision", + "transformers": "transformers>=4.25.1", + "urllib3": "urllib3<=2.0.0", +} diff --git a/diffusers/experimental/README.md b/diffusers/experimental/README.md new file mode 100644 index 0000000000000000000000000000000000000000..81a9de81c73728ea41eb6e8617a5429c3c9645ff --- /dev/null +++ b/diffusers/experimental/README.md @@ -0,0 +1,5 @@ +# 🧨 Diffusers Experimental + +We are adding experimental code to support novel applications and usages of the Diffusers library. +Currently, the following experiments are supported: +* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. \ No newline at end of file diff --git a/diffusers/experimental/__init__.py b/diffusers/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc8155403016dfd8ad7fb78d246f9da9098ac50 --- /dev/null +++ b/diffusers/experimental/__init__.py @@ -0,0 +1 @@ +from .rl import ValueGuidedRLPipeline diff --git a/diffusers/experimental/rl/__init__.py b/diffusers/experimental/rl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b338d3173e12d478b6b6d6fd0e50650a0ab5a4c --- /dev/null +++ b/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/diffusers/experimental/rl/value_guided_sampling.py b/diffusers/experimental/rl/value_guided_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..e4af4986faad9c1e81a5cf4ee76138f3db00ab44 --- /dev/null +++ b/diffusers/experimental/rl/value_guided_sampling.py @@ -0,0 +1,152 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import tqdm + +from ...models.unet_1d import UNet1DModel +from ...pipelines import DiffusionPipeline +from ...utils import randn_tensor +from ...utils.dummy_pt_objects import DDPMScheduler + + +class ValueGuidedRLPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Pipeline for sampling actions from a diffusion model trained to predict sequences of states. + + Original implementation inspired by this repository: https://github.com/jannerm/diffuser. + + Parameters: + value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward. + unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this + application is [`DDPMScheduler`]. + env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models. + """ + + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + self.value_function = value_function + self.unet = unet + self.scheduler = scheduler + self.env = env + self.data = env.get_dataset() + self.means = {} + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: # noqa: E722 + pass + self.stds = {} + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: # noqa: E722 + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if type(x_in) is dict: + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + + # permute to match dimension for pre-trained models + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + + # TODO: verify deprecation of this kwarg + x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # apply conditions to the trajectory (set the initial state) + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) + x1 = randn_tensor(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + + # run the diffusion process + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = np.random.randint(0, batch_size) + + denorm_actions = denorm_actions[selected_index, 0] + return denorm_actions diff --git a/diffusers/image_processor.py b/diffusers/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..6ccf9b465ebd4cd6ce48a40dfe45bbc70d1f3416 --- /dev/null +++ b/diffusers/image_processor.py @@ -0,0 +1,366 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from PIL import Image + +from .configuration_utils import ConfigMixin, register_to_config +from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate + + +class VaeImageProcessor(ConfigMixin): + """ + Image processor for VAE. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = True, + do_convert_rgb: bool = False, + ): + super().__init__() + + @staticmethod + def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + @staticmethod + def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: + """ + Convert a PIL image or a list of PIL images to NumPy arrays. + """ + if not isinstance(images, list): + images = [images] + images = [np.array(image).astype(np.float32) / 255.0 for image in images] + images = np.stack(images, axis=0) + + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: + """ + Convert a NumPy image to a PyTorch tensor. + """ + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + @staticmethod + def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: + """ + Convert a PyTorch tensor to a NumPy image. + """ + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + @staticmethod + def normalize(images): + """ + Normalize an image array to [-1,1]. + """ + return 2.0 * images - 1.0 + + @staticmethod + def denormalize(images): + """ + Denormalize an image array to [0,1]. + """ + return (images / 2 + 0.5).clamp(0, 1) + + @staticmethod + def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: + """ + Converts an image to RGB format. + """ + image = image.convert("RGB") + return image + + def resize( + self, + image: PIL.Image.Image, + height: Optional[int] = None, + width: Optional[int] = None, + ) -> PIL.Image.Image: + """ + Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`. + """ + if height is None: + height = image.height + if width is None: + width = image.width + + width, height = ( + x - x % self.config.vae_scale_factor for x in (width, height) + ) # resize to integer multiple of vae_scale_factor + image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) + return image + + def preprocess( + self, + image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + height: Optional[int] = None, + width: Optional[int] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + if isinstance(image, supported_formats): + image = [image] + elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" + ) + + if isinstance(image[0], PIL.Image.Image): + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] + if self.config.do_resize: + image = [self.resize(i, height, width) for i in image] + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = self.numpy_to_pt(image) + _, _, height, width = image.shape + if self.config.do_resize and ( + height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 + ): + raise ValueError( + f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}" + f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor" + ) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + _, channel, height, width = image.shape + + # don't need any preprocess if the image is latents + if channel == 4: + return image + + if self.config.do_resize and ( + height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 + ): + raise ValueError( + f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}" + f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor" + ) + + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + + if do_normalize: + image = self.normalize(image) + + return image + + def postprocess( + self, + image: torch.FloatTensor, + output_type: str = "pil", + do_denormalize: Optional[List[bool]] = None, + ): + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if output_type == "latent": + return image + + if do_denormalize is None: + do_denormalize = [self.config.do_normalize] * image.shape[0] + + image = torch.stack( + [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] + ) + + if output_type == "pt": + return image + + image = self.pt_to_numpy(image) + + if output_type == "np": + return image + + if output_type == "pil": + return self.numpy_to_pil(image) + + +class VaeImageProcessorLDM3D(VaeImageProcessor): + """ + Image processor for VAE LDM3D. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = True, + ): + super().__init__() + + @staticmethod + def numpy_to_pil(images): + """ + Convert a NumPy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image[:, :, :3]) for image in images] + + return pil_images + + @staticmethod + def rgblike_to_depthmap(image): + """ + Args: + image: RGB-like depth image + + Returns: depth map + + """ + return image[:, :, 1] * 2**8 + image[:, :, 2] + + def numpy_to_depth(self, images): + """ + Convert a NumPy depth image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images_depth = images[:, :, :, 3:] + if images.shape[-1] == 6: + images_depth = (images_depth * 255).round().astype("uint8") + pil_images = [ + Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth + ] + elif images.shape[-1] == 4: + images_depth = (images_depth * 65535.0).astype(np.uint16) + pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth] + else: + raise Exception("Not supported") + + return pil_images + + def postprocess( + self, + image: torch.FloatTensor, + output_type: str = "pil", + do_denormalize: Optional[List[bool]] = None, + ): + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if do_denormalize is None: + do_denormalize = [self.config.do_normalize] * image.shape[0] + + image = torch.stack( + [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] + ) + + image = self.pt_to_numpy(image) + + if output_type == "np": + if image.shape[-1] == 6: + image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0) + else: + image_depth = image[:, :, :, 3:] + return image[:, :, :, :3], image_depth + + if output_type == "pil": + return self.numpy_to_pil(image), self.numpy_to_depth(image) + else: + raise Exception(f"This type {output_type} is not supported") diff --git a/diffusers/loaders.py b/diffusers/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce5989b5f49504954a9f4157548376a46b3630b --- /dev/null +++ b/diffusers/loaders.py @@ -0,0 +1,1874 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import warnings +from collections import defaultdict +from contextlib import nullcontext +from io import BytesIO +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import requests +import torch +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from torch import nn + +from .models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, + AttnProcessor2_0, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRALinearLayer, + LoRAXFormersAttnProcessor, + SlicedAttnAddedKVProcessor, + XFormersAttnProcessor, +) +from .utils import ( + DIFFUSERS_CACHE, + HF_HUB_OFFLINE, + _get_model_file, + deprecate, + is_accelerate_available, + is_omegaconf_available, + is_safetensors_available, + is_transformers_available, + logging, +) +from .utils.import_utils import BACKENDS_MAPPING + + +if is_safetensors_available(): + import safetensors + +if is_transformers_available(): + from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) + +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + +TEXT_INVERSION_NAME = "learned_embeds.bin" +TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" + +CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" +CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" + + +class PatchedLoraProjection(nn.Module): + def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): + super().__init__() + self.regular_linear_layer = regular_linear_layer + + device = self.regular_linear_layer.weight.device + + if dtype is None: + dtype = self.regular_linear_layer.weight.dtype + + self.lora_linear_layer = LoRALinearLayer( + self.regular_linear_layer.in_features, + self.regular_linear_layer.out_features, + network_alpha=network_alpha, + device=device, + dtype=dtype, + rank=rank, + ) + + self.lora_scale = lora_scale + + def forward(self, input): + return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input) + + +def text_encoder_attn_modules(text_encoder): + attn_modules = [] + + if isinstance(text_encoder, CLIPTextModel): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") + + return attn_modules + + +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + +class AttnProcsLayers(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = dict(enumerate(state_dict.keys())) + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # .processor for unet, .self_attn for text encoder + self.split_keys = [".processor", ".self_attn"] + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def remap_key(key, state_dict): + for k in self.split_keys: + if k in key: + return key.split(k)[0] + k + + raise ValueError( + f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}." + ) + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = remap_key(key, state_dict) + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +class UNet2DConditionLoadersMixin: + text_encoder_name = TEXT_ENCODER_NAME + unet_name = UNET_NAME + + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be + defined in + [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) + and be a `torch.nn.Module` class. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a directory (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + + """ + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + network_alpha = kwargs.pop("network_alpha", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # fill attn processors + attn_processors = {} + + is_lora = all("lora" in k for k in state_dict.keys()) + is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) + + if is_lora: + is_new_lora_format = all( + key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + ) + if is_new_lora_format: + # Strip the `"unet"` prefix. + is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys()) + if is_text_encoder_present: + warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)." + warnings.warn(warn_message) + unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)] + state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} + + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processor = self + for sub_key in key.split("."): + attn_processor = getattr(attn_processor, sub_key) + + if isinstance( + attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) + ): + cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] + attn_processor_class = LoRAAttnAddedKVProcessor + else: + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): + attn_processor_class = LoRAXFormersAttnProcessor + else: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + + attn_processors[key] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=rank, + network_alpha=network_alpha, + ) + attn_processors[key].load_state_dict(value_dict) + elif is_custom_diffusion: + custom_diffusion_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + if len(value) == 0: + custom_diffusion_grouped_dict[key] = {} + else: + if "to_out" in key: + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + else: + attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) + custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in custom_diffusion_grouped_dict.items(): + if len(value_dict) == 0: + attn_processors[key] = CustomDiffusionAttnProcessor( + train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None + ) + else: + cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] + hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0] + train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False + attn_processors[key] = CustomDiffusionAttnProcessor( + train_kv=True, + train_q_out=train_q_out, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ) + attn_processors[key].load_state_dict(value_dict) + else: + raise ValueError( + f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training." + ) + + # set correct dtype & device + attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()} + + # set layers + self.set_attn_processor(attn_processors) + + def save_attn_procs( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = False, + **kwargs, + ): + r""" + Save an attention processor to a directory so that it can be reloaded using the + [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save an attention processor to. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + + """ + weight_name = weight_name or deprecate( + "weights_name", + "0.20.0", + "`weights_name` is deprecated, please use `weight_name` instead.", + take_from=kwargs, + ) + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + is_custom_diffusion = any( + isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)) + for (_, x) in self.attn_processors.items() + ) + if is_custom_diffusion: + model_to_save = AttnProcsLayers( + { + y: x + for (y, x) in self.attn_processors.items() + if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)) + } + ) + state_dict = model_to_save.state_dict() + for name, attn in self.attn_processors.items(): + if len(attn.state_dict()) == 0: + state_dict[name] = {} + else: + model_to_save = AttnProcsLayers(self.attn_processors) + state_dict = model_to_save.state_dict() + + if weight_name is None: + if safe_serialization: + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE + else: + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME + + # Save the model + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + + +class TextualInversionLoaderMixin: + r""" + Load textual inversion tokens and embeddings to the tokenizer and text encoder. + """ + + def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): + r""" + Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to + be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual + inversion token or if the textual inversion token is a single vector, the input prompt is returned. + + Parameters: + prompt (`str` or list of `str`): + The prompt or prompts to guide the image generation. + tokenizer (`PreTrainedTokenizer`): + The tokenizer responsible for encoding the prompt into input tokens. + + Returns: + `str` or list of `str`: The converted prompt + """ + if not isinstance(prompt, List): + prompts = [prompt] + else: + prompts = prompt + + prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts] + + if not isinstance(prompt, List): + return prompts[0] + + return prompts + + def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): + r""" + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds + to a multi-vector textual inversion embedding, this function will process the prompt so that the special token + is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual + inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. + + Parameters: + prompt (`str`): + The prompt to guide the image generation. + tokenizer (`PreTrainedTokenizer`): + The tokenizer responsible for encoding the prompt into input tokens. + + Returns: + `str`: The converted prompt + """ + tokens = tokenizer.tokenize(prompt) + unique_tokens = set(tokens) + for token in unique_tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f" {token}_{i}" + i += 1 + + prompt = prompt.replace(token, replacement) + + return prompt + + def load_textual_inversion( + self, + pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], + token: Optional[Union[str, List[str]]] = None, + **kwargs, + ): + r""" + Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and + Automatic1111 formats are supported). + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`): + Can be either one of the following or a list of them: + + - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a + pretrained model hosted on the Hub. + - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual + inversion weights. + - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + token (`str` or `List[str]`, *optional*): + Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a + list, then `token` must also be a list of equal length. + weight_name (`str`, *optional*): + Name of a custom weight file. This should be used when: + + - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight + name such as `text_inv.bin`. + - The saved textual inversion file is in the Automatic1111 format. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + + Example: + + To load a textual inversion embedding vector in 🤗 Diffusers format: + + ```py + from diffusers import StableDiffusionPipeline + import torch + + model_id = "runwayml/stable-diffusion-v1-5" + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + + pipe.load_textual_inversion("sd-concepts-library/cat-toy") + + prompt = "A backpack" + + image = pipe(prompt, num_inference_steps=50).images[0] + image.save("cat-backpack.png") + ``` + + To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first + (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector + locally: + + ```py + from diffusers import StableDiffusionPipeline + import torch + + model_id = "runwayml/stable-diffusion-v1-5" + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + + pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2") + + prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details." + + image = pipe(prompt, num_inference_steps=50).images[0] + image.save("character.png") + ``` + + """ + if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): + raise ValueError( + f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" + f" `{self.load_textual_inversion.__name__}`" + ) + + if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): + raise ValueError( + f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" + f" `{self.load_textual_inversion.__name__}`" + ) + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "text_inversion", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path, list): + pretrained_model_name_or_paths = [pretrained_model_name_or_path] + else: + pretrained_model_name_or_paths = pretrained_model_name_or_path + + if isinstance(token, str): + tokens = [token] + elif token is None: + tokens = [None] * len(pretrained_model_name_or_paths) + else: + tokens = token + + if len(pretrained_model_name_or_paths) != len(tokens): + raise ValueError( + f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}" + f"Make sure both lists have the same length." + ) + + valid_tokens = [t for t in tokens if t is not None] + if len(set(valid_tokens)) < len(valid_tokens): + raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}") + + token_ids_and_embeddings = [] + + for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens): + if not isinstance(pretrained_model_name_or_path, dict): + # 1. Load textual inversion file + model_file = None + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except Exception as e: + if not allow_pickle: + raise e + + model_file = None + + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path + + # 2. Load token and embedding correcly from file + loaded_token = None + if isinstance(state_dict, torch.Tensor): + if token is None: + raise ValueError( + "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." + ) + embedding = state_dict + elif len(state_dict) == 1: + # diffusers + loaded_token, embedding = next(iter(state_dict.items())) + elif "string_to_param" in state_dict: + # A1111 + loaded_token = state_dict["name"] + embedding = state_dict["string_to_param"]["*"] + + if token is not None and loaded_token != token: + logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") + else: + token = loaded_token + + embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) + + # 3. Make sure we don't mess up the tokenizer or text encoder + vocab = self.tokenizer.get_vocab() + if token in vocab: + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." + ) + elif f"{token}_1" in vocab: + multi_vector_tokens = [token] + i = 1 + while f"{token}_{i}" in self.tokenizer.added_tokens_encoder: + multi_vector_tokens.append(f"{token}_{i}") + i += 1 + + raise ValueError( + f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." + ) + + is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1 + + if is_multi_vector: + tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])] + embeddings = [e for e in embedding] # noqa: C416 + else: + tokens = [token] + embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding] + + # add tokens and get ids + self.tokenizer.add_tokens(tokens) + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + token_ids_and_embeddings += zip(token_ids, embeddings) + + logger.info(f"Loaded textual inversion embedding for {token}.") + + # resize token embeddings and set all new embeddings + self.text_encoder.resize_token_embeddings(len(self.tokenizer)) + for token_id, embedding in token_ids_and_embeddings: + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + + +class LoraLoaderMixin: + r""" + Load LoRA layers into [`UNet2DConditionModel`] and + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + """ + text_encoder_name = TEXT_ENCODER_NAME + unet_name = UNET_NAME + + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into self.unet and self.text_encoder. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + + See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into + `self.unet`. + + See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded + into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + + kwargs: + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + """ + state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + self.load_lora_into_text_encoder( + state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale + ) + + @classmethod + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except (IOError, safetensors.SafetensorError) as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # Convert kohya-ss Style LoRA attn procs to diffusers attn procs + network_alpha = None + if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): + state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict) + + return state_dict, network_alpha + + @classmethod + def load_lora_into_unet(cls, state_dict, network_alpha, unet): + """ + This will load the LoRA layers specified in `state_dict` into `unet` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alpha (`float`): + See `LoRALinearLayer` for more details. + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + """ + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): + # Load the layers corresponding to UNet. + unet_keys = [k for k in keys if k.startswith(cls.unet_name)] + logger.info(f"Loading {cls.unet_name}.") + unet_lora_state_dict = { + k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys + } + unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) + + # Otherwise, we're dealing with the old format. This means the `state_dict` should only + # contain the module names of the `unet` as its keys WITHOUT any prefix. + elif not all( + key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys() + ): + unet.load_attn_procs(state_dict) + warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." + warnings.warn(warn_message) + + @classmethod + def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key shoult be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alpha (`float`): + See `LoRALinearLayer` for more details. + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + """ + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)] + text_encoder_lora_state_dict = { + k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {cls.text_encoder_name}.") + + if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): + # Convert from the old naming convention to the new naming convention. + # + # Previously, the old LoRA layers were stored on the state dict at the + # same level as the attention block i.e. + # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. + # + # This is no actual module at that point, they were monkey patched on to the + # existing module. We want to be able to load them via their actual state dict. + # They're in `PatchedLoraProjection.lora_linear_layer` now. + for name, _ in text_encoder_attn_modules(text_encoder): + text_encoder_lora_state_dict[ + f"{name}.q_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.k_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.v_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.out_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") + + text_encoder_lora_state_dict[ + f"{name}.q_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.k_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.v_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.out_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") + + rank = text_encoder_lora_state_dict[ + "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight" + ].shape[1] + + cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank) + + # set correct dtype & device + text_encoder_lora_state_dict = { + k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) + for k, v in text_encoder_lora_state_dict.items() + } + + load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) + if len(load_state_dict_results.unexpected_keys) != 0: + raise ValueError( + f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + ) + + @property + def lora_scale(self) -> float: + # property function that returns the lora scale which can be set at run time by the pipeline. + # if _lora_scale has not been set, return 1 + return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + + @classmethod + def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj = attn_module.q_proj.regular_linear_layer + attn_module.k_proj = attn_module.k_proj.regular_linear_layer + attn_module.v_proj = attn_module.v_proj.regular_linear_layer + attn_module.out_proj = attn_module.out_proj.regular_linear_layer + + @classmethod + def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None): + r""" + Monkey-patches the forward passes of attention modules of the text encoder. + """ + + # First, remove any monkey-patch that might have been applied before + cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) + + lora_parameters = [] + + for _, attn_module in text_encoder_attn_modules(text_encoder): + attn_module.q_proj = PatchedLoraProjection( + attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters()) + + attn_module.k_proj = PatchedLoraProjection( + attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters()) + + attn_module.v_proj = PatchedLoraProjection( + attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters()) + + attn_module.out_proj = PatchedLoraProjection( + attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters()) + + return lora_parameters + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = False, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the UNet. + text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + # Create a flat dictionary. + state_dict = {} + if unet_lora_layers is not None: + weights = ( + unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers + ) + + unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()} + state_dict.update(unet_lora_state_dict) + + if text_encoder_lora_layers is not None: + weights = ( + text_encoder_lora_layers.state_dict() + if isinstance(text_encoder_lora_layers, torch.nn.Module) + else text_encoder_lora_layers + ) + + text_encoder_lora_state_dict = { + f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() + } + state_dict.update(text_encoder_lora_state_dict) + + # Save the model + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + + @classmethod + def _convert_kohya_lora_to_diffusers(cls, state_dict): + unet_state_dict = {} + te_state_dict = {} + network_alpha = None + + for key, value in state_dict.items(): + if "lora_down" in key: + lora_name = key.split(".")[0] + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + if lora_name_alpha in state_dict: + alpha = state_dict[lora_name_alpha].item() + if network_alpha is None: + network_alpha = alpha + elif network_alpha != alpha: + raise ValueError("Network alpha is not consistent") + + if lora_name.startswith("lora_unet_"): + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + unet_state_dict[diffusers_name] = value + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif lora_name.startswith("lora_te_"): + diffusers_name = key.replace("lora_te_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = value + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + + unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} + new_state_dict = {**unet_state_dict, **te_state_dict} + return new_state_dict, network_alpha + + def unload_lora_weights(self): + """ + Unloads the LoRA parameters. + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the LoRA parameters. + >>> pipeline.unload_lora_weights() + >>> ... + ``` + """ + is_unet_lora = all( + isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor)) + for _, processor in self.unet.attn_processors.items() + ) + # Handle attention processors that are a mix of regular attention and AddedKV + # attention. + if is_unet_lora: + is_attn_procs_mixed = all( + isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor)) + for _, processor in self.unet.attn_processors.items() + ) + if not is_attn_procs_mixed: + unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + self.unet.set_attn_processor(unet_attn_proc_cls()) + else: + self.unet.set_default_attn_processor() + + # Safe to call the following regardless of LoRA. + self._remove_text_encoder_monkey_patch() + + +class FromSingleFileMixin: + """ + Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`]. + """ + + @classmethod + def from_ckpt(cls, *args, **kwargs): + deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead." + deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False) + return cls.from_single_file(*args, **kwargs) + + @classmethod + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors` + format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + extract_ema (`bool`, *optional*, defaults to `False`): + Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield + higher quality images for inference. Non-EMA weights are usually better to continue finetuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and + the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to `None`): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to `"pndm"`): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. + text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): + An instance of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use, + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) + variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if + needed. + tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): + An instance of + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by + itself, if needed. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + Examples: + + ```py + >>> from diffusers import StableDiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = StableDiffusionPipeline.from_single_file( + ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors" + ... ) + + >>> # Download pipeline from local file + >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt + >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly") + + >>> # Enable float16 and move to GPU + >>> pipeline = StableDiffusionPipeline.from_single_file( + ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt", + ... torch_dtype=torch.float16, + ... ) + >>> pipeline.to("cuda") + ``` + """ + # import here to avoid circular dependency + from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + extract_ema = kwargs.pop("extract_ema", False) + image_size = kwargs.pop("image_size", None) + scheduler_type = kwargs.pop("scheduler_type", "pndm") + num_in_channels = kwargs.pop("num_in_channels", None) + upcast_attention = kwargs.pop("upcast_attention", None) + load_safety_checker = kwargs.pop("load_safety_checker", True) + prediction_type = kwargs.pop("prediction_type", None) + text_encoder = kwargs.pop("text_encoder", None) + controlnet = kwargs.pop("controlnet", None) + tokenizer = kwargs.pop("tokenizer", None) + + torch_dtype = kwargs.pop("torch_dtype", None) + + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + pipeline_name = cls.__name__ + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + # TODO: For now we only support stable diffusion + stable_unclip = None + model_type = None + + if pipeline_name in [ + "StableDiffusionControlNetPipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + ]: + from .models.controlnet import ControlNetModel + from .pipelines.controlnet.multicontrolnet import MultiControlNetModel + + # Model type will be inferred from the checkpoint. + if not isinstance(controlnet, (ControlNetModel, MultiControlNetModel)): + raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.") + elif "StableDiffusion" in pipeline_name: + # Model type will be inferred from the checkpoint. + pass + elif pipeline_name == "StableUnCLIPPipeline": + model_type = "FrozenOpenCLIPEmbedder" + stable_unclip = "txt2img" + elif pipeline_name == "StableUnCLIPImg2ImgPipeline": + model_type = "FrozenOpenCLIPEmbedder" + stable_unclip = "img2img" + elif pipeline_name == "PaintByExamplePipeline": + model_type = "PaintByExample" + elif pipeline_name == "LDMTextToImagePipeline": + model_type = "LDMTextToImage" + else: + raise ValueError(f"Unhandled pipeline class: {pipeline_name}") + + # remove huggingface url + for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: + if pretrained_model_link_or_path.startswith(prefix): + pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] + + # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained + ckpt_path = Path(pretrained_model_link_or_path) + if not ckpt_path.is_file(): + # get repo_id and (potentially nested) file path of ckpt in repo + repo_id = os.path.join(*ckpt_path.parts[:2]) + file_path = os.path.join(*ckpt_path.parts[2:]) + + if file_path.startswith("blob/"): + file_path = file_path[len("blob/") :] + + if file_path.startswith("main/"): + file_path = file_path[len("main/") :] + + pretrained_model_link_or_path = hf_hub_download( + repo_id, + filename=file_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + ) + + pipe = download_from_original_stable_diffusion_ckpt( + pretrained_model_link_or_path, + pipeline_class=cls, + model_type=model_type, + stable_unclip=stable_unclip, + controlnet=controlnet, + from_safetensors=from_safetensors, + extract_ema=extract_ema, + image_size=image_size, + scheduler_type=scheduler_type, + num_in_channels=num_in_channels, + upcast_attention=upcast_attention, + load_safety_checker=load_safety_checker, + prediction_type=prediction_type, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + if torch_dtype is not None: + pipe.to(torch_dtype=torch_dtype) + + return pipe + + +class FromOriginalVAEMixin: + @classmethod + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`AutoencoderKL`] from pretrained controlnet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is format. The pipeline is set in evaluation mode (`model.eval()`) by + default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z + = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution + Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + + + Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you want to load + a VAE that does accompany a stable diffusion model of v2 or higher or SDXL. + + + + Examples: + + ```py + from diffusers import AutoencoderKL + + url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file + model = AutoencoderKL.from_single_file(url) + ``` + """ + if not is_omegaconf_available(): + raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) + + from omegaconf import OmegaConf + + from .models import AutoencoderKL + + # import here to avoid circular dependency + from .pipelines.stable_diffusion.convert_from_ckpt import ( + convert_ldm_vae_checkpoint, + create_vae_diffusers_config, + ) + + config_file = kwargs.pop("config_file", None) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + image_size = kwargs.pop("image_size", None) + scaling_factor = kwargs.pop("scaling_factor", None) + kwargs.pop("upcast_attention", None) + + torch_dtype = kwargs.pop("torch_dtype", None) + + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + # remove huggingface url + for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: + if pretrained_model_link_or_path.startswith(prefix): + pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] + + # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained + ckpt_path = Path(pretrained_model_link_or_path) + if not ckpt_path.is_file(): + # get repo_id and (potentially nested) file path of ckpt in repo + repo_id = "/".join(ckpt_path.parts[:2]) + file_path = "/".join(ckpt_path.parts[2:]) + + if file_path.startswith("blob/"): + file_path = file_path[len("blob/") :] + + if file_path.startswith("main/"): + file_path = file_path[len("main/") :] + + pretrained_model_link_or_path = hf_hub_download( + repo_id, + filename=file_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + ) + + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu") + + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if config_file is None: + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + config_file = BytesIO(requests.get(config_url).content) + + original_config = OmegaConf.load(config_file) + + # default to sd-v1-5 + image_size = image_size or 512 + + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if scaling_factor is None: + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + + if torch_dtype is not None: + vae.to(torch_dtype=torch_dtype) + + return vae + + +class FromOriginalControlnetMixin: + @classmethod + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`ControlNetModel`] from pretrained controlnet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + Examples: + + ```py + from diffusers import StableDiffusionControlnetPipeline, ControlNetModel + + url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path + model = ControlNetModel.from_single_file(url) + + url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path + pipe = StableDiffusionControlnetPipeline.from_single_file(url, controlnet=controlnet) + ``` + """ + # import here to avoid circular dependency + from .pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt + + config_file = kwargs.pop("config_file", None) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + num_in_channels = kwargs.pop("num_in_channels", None) + use_linear_projection = kwargs.pop("use_linear_projection", None) + revision = kwargs.pop("revision", None) + extract_ema = kwargs.pop("extract_ema", False) + image_size = kwargs.pop("image_size", None) + upcast_attention = kwargs.pop("upcast_attention", None) + + torch_dtype = kwargs.pop("torch_dtype", None) + + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + # remove huggingface url + for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: + if pretrained_model_link_or_path.startswith(prefix): + pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] + + # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained + ckpt_path = Path(pretrained_model_link_or_path) + if not ckpt_path.is_file(): + # get repo_id and (potentially nested) file path of ckpt in repo + repo_id = "/".join(ckpt_path.parts[:2]) + file_path = "/".join(ckpt_path.parts[2:]) + + if file_path.startswith("blob/"): + file_path = file_path[len("blob/") :] + + if file_path.startswith("main/"): + file_path = file_path[len("main/") :] + + pretrained_model_link_or_path = hf_hub_download( + repo_id, + filename=file_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + ) + + if config_file is None: + config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml" + config_file = BytesIO(requests.get(config_url).content) + + image_size = image_size or 512 + + controlnet = download_controlnet_from_original_ckpt( + pretrained_model_link_or_path, + original_config_file=config_file, + image_size=image_size, + extract_ema=extract_ema, + num_in_channels=num_in_channels, + upcast_attention=upcast_attention, + from_safetensors=from_safetensors, + use_linear_projection=use_linear_projection, + ) + + if torch_dtype is not None: + controlnet.to(torch_dtype=torch_dtype) + + return controlnet diff --git a/diffusers/models/README.md b/diffusers/models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fb91f59411265660e01d8b4bcc0b99e8b8fe9d55 --- /dev/null +++ b/diffusers/models/README.md @@ -0,0 +1,3 @@ +# Models + +For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview). \ No newline at end of file diff --git a/diffusers/models/__init__.py b/diffusers/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e330a44691a73f51d109de2e268d179e9e86d87 --- /dev/null +++ b/diffusers/models/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .adapter import MultiAdapter, T2IAdapter + from .autoencoder_kl import AutoencoderKL + from .controlnet import ControlNetModel + from .dual_transformer_2d import DualTransformer2DModel + from .modeling_utils import ModelMixin + from .prior_transformer import PriorTransformer + from .t5_film_transformer import T5FilmDecoder + from .transformer_2d import Transformer2DModel + from .unet_1d import UNet1DModel + from .unet_2d import UNet2DModel + from .unet_2d_condition import UNet2DConditionModel + from .unet_3d_condition import UNet3DConditionModel + from .vq_model import VQModel + +if is_flax_available(): + from .controlnet_flax import FlaxControlNetModel + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL diff --git a/diffusers/models/__pycache__/__init__.cpython-310.pyc b/diffusers/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d41f8708cf72411f429c0152a8540dd52f8f7366 Binary files /dev/null and b/diffusers/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/__init__.cpython-38.pyc b/diffusers/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7d85c887654ef4ce8dfecee355c598a812da002 Binary files /dev/null and b/diffusers/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/activations.cpython-310.pyc b/diffusers/models/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..430c21f17aff487f040247c33b564d7c1d6ada25 Binary files /dev/null and b/diffusers/models/__pycache__/activations.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/activations.cpython-38.pyc b/diffusers/models/__pycache__/activations.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c8a513da41b63b51c355044df7b93c3269dceda Binary files /dev/null and b/diffusers/models/__pycache__/activations.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/adapter.cpython-310.pyc b/diffusers/models/__pycache__/adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d743cc5a8af08b4b44adb636ecdca658881ad37d Binary files /dev/null and b/diffusers/models/__pycache__/adapter.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/adapter.cpython-38.pyc b/diffusers/models/__pycache__/adapter.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59fdd45dde4bd0bb24016ed510edc339b2620084 Binary files /dev/null and b/diffusers/models/__pycache__/adapter.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/attention.cpython-310.pyc b/diffusers/models/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70cfc42e993409e73081e637a2047c088a9f4e01 Binary files /dev/null and b/diffusers/models/__pycache__/attention.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/attention.cpython-38.pyc b/diffusers/models/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa7699c170b1a9c2672ce06edd6e3747356a3288 Binary files /dev/null and b/diffusers/models/__pycache__/attention.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/attention_processor.cpython-310.pyc b/diffusers/models/__pycache__/attention_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e9ed21e0de471a8a8d08b8c5271707b55bdafa8 Binary files /dev/null and b/diffusers/models/__pycache__/attention_processor.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/attention_processor.cpython-38.pyc b/diffusers/models/__pycache__/attention_processor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ceeb4fb3d80d220b9978f137d72c43d005b092c Binary files /dev/null and b/diffusers/models/__pycache__/attention_processor.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc b/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b3b7362866f45f2d9a42bd42f7b9e226212d024 Binary files /dev/null and b/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/autoencoder_kl.cpython-38.pyc b/diffusers/models/__pycache__/autoencoder_kl.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bcaa7b767796f38b2e5e3cd61eb09a2d19a2c24 Binary files /dev/null and b/diffusers/models/__pycache__/autoencoder_kl.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/controlnet.cpython-310.pyc b/diffusers/models/__pycache__/controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe07158e28a0e8977c6fe6cb356831025dcebb3b Binary files /dev/null and b/diffusers/models/__pycache__/controlnet.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/controlnet.cpython-38.pyc b/diffusers/models/__pycache__/controlnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4ae9403fa341fa5121c0b00f96cc62d6418aa61 Binary files /dev/null and b/diffusers/models/__pycache__/controlnet.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/controlnet_composer.cpython-310.pyc b/diffusers/models/__pycache__/controlnet_composer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f26ae29eb6d8e6e335d47e27813cec8fbe267c3 Binary files /dev/null and b/diffusers/models/__pycache__/controlnet_composer.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/controlnet_composer.cpython-38.pyc b/diffusers/models/__pycache__/controlnet_composer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c353b814edfb02522241318441c65223d8a2916 Binary files /dev/null and b/diffusers/models/__pycache__/controlnet_composer.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc b/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf9b8d3972723e0ebeca90e6344b1b63a4fbb9aa Binary files /dev/null and b/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/dual_transformer_2d.cpython-38.pyc b/diffusers/models/__pycache__/dual_transformer_2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7755b8183aebb795eb674c34c50fbb46baf41916 Binary files /dev/null and b/diffusers/models/__pycache__/dual_transformer_2d.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/embeddings.cpython-310.pyc b/diffusers/models/__pycache__/embeddings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d82a2d6acf832243e90144e2379fc51a2ebd46f Binary files /dev/null and b/diffusers/models/__pycache__/embeddings.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/embeddings.cpython-38.pyc b/diffusers/models/__pycache__/embeddings.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f93c7669346644de0de8adc993b83afd9cbe1a21 Binary files /dev/null and b/diffusers/models/__pycache__/embeddings.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc b/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bc965ae71474dea4a451335a7187a988c483eba Binary files /dev/null and b/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/modeling_utils.cpython-38.pyc b/diffusers/models/__pycache__/modeling_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5490aa849fd18b500efba6fec7c63d4313ef7c1 Binary files /dev/null and b/diffusers/models/__pycache__/modeling_utils.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc b/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b32127351caa0cf733f34629da9e1a80aaab280 Binary files /dev/null and b/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/prior_transformer.cpython-38.pyc b/diffusers/models/__pycache__/prior_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3018b1dfa7e673a3e112583749bf2557bda6cd9e Binary files /dev/null and b/diffusers/models/__pycache__/prior_transformer.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/resnet.cpython-310.pyc b/diffusers/models/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63bb6d768605db91ae1c025e395ea6729b701284 Binary files /dev/null and b/diffusers/models/__pycache__/resnet.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/resnet.cpython-38.pyc b/diffusers/models/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5c143d3e5443ab82f60e0b8a71130794c7e777e Binary files /dev/null and b/diffusers/models/__pycache__/resnet.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc b/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d48a732862a603a63551ee63f562dbb4c60ce98c Binary files /dev/null and b/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/t5_film_transformer.cpython-38.pyc b/diffusers/models/__pycache__/t5_film_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17fc2eb047ebb67d14d13a635ed975bcc5c2d9a2 Binary files /dev/null and b/diffusers/models/__pycache__/t5_film_transformer.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc b/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b09e58232c6297a886e2f50901c5a2be045ccf Binary files /dev/null and b/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/transformer_2d.cpython-38.pyc b/diffusers/models/__pycache__/transformer_2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f3784d96ccdcaddfb777906a28e8417ffe23427 Binary files /dev/null and b/diffusers/models/__pycache__/transformer_2d.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc b/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75f463b8e2707cc17655b01e758f624bcc375f32 Binary files /dev/null and b/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/transformer_temporal.cpython-38.pyc b/diffusers/models/__pycache__/transformer_temporal.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41796f0f1898457ffcd63a5c587d15c734b7bd4d Binary files /dev/null and b/diffusers/models/__pycache__/transformer_temporal.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/unet_1d.cpython-310.pyc b/diffusers/models/__pycache__/unet_1d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d82d8fdaa98c571329579c6e0ce029044b7054a4 Binary files /dev/null and b/diffusers/models/__pycache__/unet_1d.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_1d.cpython-38.pyc b/diffusers/models/__pycache__/unet_1d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a05ba3ad92d7a1a774c102b5a5859be40f03cd6 Binary files /dev/null and b/diffusers/models/__pycache__/unet_1d.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc b/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db61b43437e3f1376f165c6c06e30508ad3e3c82 Binary files /dev/null and b/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_1d_blocks.cpython-38.pyc b/diffusers/models/__pycache__/unet_1d_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02206aaf19e703a6f0be9781b70fb12a0647e83c Binary files /dev/null and b/diffusers/models/__pycache__/unet_1d_blocks.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a563a6f94ef8ff56ada1bb45e11957e07db9423f Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d.cpython-38.pyc b/diffusers/models/__pycache__/unet_2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a39afdf290f9a560a8fa9f71e4b046046c05f85e Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0ef358590c1b02d291eccad421d9e193514a018 Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d_blocks.cpython-38.pyc b/diffusers/models/__pycache__/unet_2d_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bfe3515ccad52b5b724374afb7f9475fbf04877 Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_blocks.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc60cef1f039a325a9d6fef9dbb32690c1cbc67a Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc b/diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea411acd7fab6fcdde4ff9cda0c0f0a729be1fb1 Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d_condition_multi_branch.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d_condition_multi_branch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1283f9021f1fc7660bd666fd6612dada0fa003d Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition_multi_branch.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca9aab4d96b39d8b54d160152ce568b534c8e0d7 Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-38.pyc b/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f323b6876388e517b9902d45db6f670252728ac Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc b/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04fdec397efe9eab97c56e7f62828cbbdac13c35 Binary files /dev/null and b/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_3d_blocks.cpython-38.pyc b/diffusers/models/__pycache__/unet_3d_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e14b6d8e3fc156e8bb4b03f14254e33c6bd06cf Binary files /dev/null and b/diffusers/models/__pycache__/unet_3d_blocks.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc b/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0744cf16bd4025d76da50a32f19651c2748f21d4 Binary files /dev/null and b/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/unet_3d_condition.cpython-38.pyc b/diffusers/models/__pycache__/unet_3d_condition.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdb9505e5e652469d8810beef3ee8b366a98143d Binary files /dev/null and b/diffusers/models/__pycache__/unet_3d_condition.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/vae.cpython-310.pyc b/diffusers/models/__pycache__/vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d880c6523df5a735260dab5fd9f578c76334f0a4 Binary files /dev/null and b/diffusers/models/__pycache__/vae.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/vae.cpython-38.pyc b/diffusers/models/__pycache__/vae.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8c9d73944fae981cc7fd4d1301b18edd6853207 Binary files /dev/null and b/diffusers/models/__pycache__/vae.cpython-38.pyc differ diff --git a/diffusers/models/__pycache__/vq_model.cpython-310.pyc b/diffusers/models/__pycache__/vq_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de5463e3803e77fc9e78b1fa22ab9bc8f363a87f Binary files /dev/null and b/diffusers/models/__pycache__/vq_model.cpython-310.pyc differ diff --git a/diffusers/models/__pycache__/vq_model.cpython-38.pyc b/diffusers/models/__pycache__/vq_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4532f2e1c2b4d1b223b408c2ad81c09e0bee161c Binary files /dev/null and b/diffusers/models/__pycache__/vq_model.cpython-38.pyc differ diff --git a/diffusers/models/activations.py b/diffusers/models/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..64759b706e2f108803e51ccd50f9dff67ad49722 --- /dev/null +++ b/diffusers/models/activations.py @@ -0,0 +1,12 @@ +from torch import nn + + +def get_activation(act_fn): + if act_fn in ["swish", "silu"]: + return nn.SiLU() + elif act_fn == "mish": + return nn.Mish() + elif act_fn == "gelu": + return nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") diff --git a/diffusers/models/adapter.py b/diffusers/models/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..a65a3873b130a4b74b46ebfb34b99067ee1a6a6e --- /dev/null +++ b/diffusers/models/adapter.py @@ -0,0 +1,291 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from .modeling_utils import ModelMixin +from .resnet import Downsample2D + + +class MultiAdapter(ModelMixin): + r""" + MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to + user-assigned weighting. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + adapters (`List[T2IAdapter]`, *optional*, defaults to None): + A list of `T2IAdapter` model instances. + """ + + def __init__(self, adapters: List["T2IAdapter"]): + super(MultiAdapter, self).__init__() + + self.num_adapter = len(adapters) + self.adapters = nn.ModuleList(adapters) + + def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]: + r""" + Args: + xs (`torch.Tensor`): + (batch, channel, height, width) input images for multiple adapter models concated along dimension 1, + `channel` should equal to `num_adapter` * "number of channel of image". + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding + them together. + """ + if adapter_weights is None: + adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter) + else: + adapter_weights = torch.tensor(adapter_weights) + + if xs.shape[1] % self.num_adapter != 0: + raise ValueError( + f"Expecting multi-adapter's input have number of channel that cab be evenly divisible " + f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0" + ) + x_list = torch.chunk(xs, self.num_adapter, dim=1) + accume_state = None + for x, w, adapter in zip(x_list, adapter_weights, self.adapters): + features = adapter(x) + if accume_state is None: + accume_state = features + else: + for i in range(len(features)): + accume_state[i] += w * features[i] + return accume_state + + +class T2IAdapter(ModelMixin, ConfigMixin): + r""" + A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model + generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's + architecture follows the original implementation of + [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97) + and + [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (`int`, *optional*, defaults to 3): + Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale + image as *control image*. + channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will + also determine the number of downsample blocks in the Adapter. + num_res_blocks (`int`, *optional*, defaults to 2): + Number of ResNet blocks in each downsample block + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280, 1280], + num_res_blocks: int = 2, + downscale_factor: int = 8, + adapter_type: str = "full_adapter", + ): + super().__init__() + + if adapter_type == "full_adapter": + self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor) + elif adapter_type == "light_adapter": + self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor) + else: + raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'") + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + return self.adapter(x) + + @property + def total_downscale_factor(self): + return self.adapter.total_downscale_factor + + +# full adapter + + +class FullAdapter(nn.Module): + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280, 1280], + num_res_blocks: int = 2, + downscale_factor: int = 8, + ): + super().__init__() + + in_channels = in_channels * downscale_factor**2 + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) + self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1) + + self.body = nn.ModuleList( + [ + AdapterBlock(channels[0], channels[0], num_res_blocks), + *[ + AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True) + for i in range(1, len(channels)) + ], + ] + ) + + self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.unshuffle(x) + x = self.conv_in(x) + + features = [] + + for block in self.body: + x = block(x) + features.append(x) + + return features + + +class AdapterBlock(nn.Module): + def __init__(self, in_channels, out_channels, num_res_blocks, down=False): + super().__init__() + + self.downsample = None + if down: + self.downsample = Downsample2D(in_channels) + + self.in_conv = None + if in_channels != out_channels: + self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + self.resnets = nn.Sequential( + *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)], + ) + + def forward(self, x): + if self.downsample is not None: + x = self.downsample(x) + + if self.in_conv is not None: + x = self.in_conv(x) + + x = self.resnets(x) + + return x + + +class AdapterResnetBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.act = nn.ReLU() + self.block2 = nn.Conv2d(channels, channels, kernel_size=1) + + def forward(self, x): + h = x + h = self.block1(h) + h = self.act(h) + h = self.block2(h) + + return h + x + + +# light adapter + + +class LightAdapter(nn.Module): + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280], + num_res_blocks: int = 4, + downscale_factor: int = 8, + ): + super().__init__() + + in_channels = in_channels * downscale_factor**2 + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) + + self.body = nn.ModuleList( + [ + LightAdapterBlock(in_channels, channels[0], num_res_blocks), + *[ + LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True) + for i in range(len(channels) - 1) + ], + LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True), + ] + ) + + self.total_downscale_factor = downscale_factor * (2 ** len(channels)) + + def forward(self, x): + x = self.unshuffle(x) + + features = [] + + for block in self.body: + x = block(x) + features.append(x) + + return features + + +class LightAdapterBlock(nn.Module): + def __init__(self, in_channels, out_channels, num_res_blocks, down=False): + super().__init__() + mid_channels = out_channels // 4 + + self.downsample = None + if down: + self.downsample = Downsample2D(in_channels) + + self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1) + self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)]) + self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1) + + def forward(self, x): + if self.downsample is not None: + x = self.downsample(x) + + x = self.in_conv(x) + x = self.resnets(x) + x = self.out_conv(x) + + return x + + +class LightAdapterResnetBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.act = nn.ReLU() + self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + + def forward(self, x): + h = x + h = self.block1(h) + h = self.act(h) + h = self.block2(h) + + return h + x diff --git a/diffusers/models/attention.py b/diffusers/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6b05bf35e87f503df3e265bd587d1ca3f32f2bc5 --- /dev/null +++ b/diffusers/models/attention.py @@ -0,0 +1,389 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import maybe_allow_in_graph +from .activations import get_activation +from .attention_processor import Attention +from .embeddings import CombinedTimestepLabelEmbeddings + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + self.approximate = approximate + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class AdaLayerNormZero(nn.Module): + """ + Norm layer adaptive layer norm zero (adaLN-Zero). + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, timestep, class_labels, hidden_dtype=None): + emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaGroupNorm(nn.Module): + """ + GroupNorm layer modified to incorporate timestep embeddings. + """ + + def __init__( + self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) + + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x, emb): + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x diff --git a/diffusers/models/attention_flax.py b/diffusers/models/attention_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..0b160d2384311c1fb426b87c11e5fa1572584070 --- /dev/null +++ b/diffusers/models/attention_flax.py @@ -0,0 +1,446 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import math + +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): + """Multi-head dot product attention with a limited number of queries.""" + num_kv, num_heads, k_features = key.shape[-3:] + v_features = value.shape[-1] + key_chunk_size = min(key_chunk_size, num_kv) + query = query / jnp.sqrt(k_features) + + @functools.partial(jax.checkpoint, prevent_cse=False) + def summarize_chunk(query, key, value): + attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) + + max_score = jnp.max(attn_weights, axis=-1, keepdims=True) + max_score = jax.lax.stop_gradient(max_score) + exp_weights = jnp.exp(attn_weights - max_score) + + exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) + max_score = jnp.einsum("...qhk->...qh", max_score) + + return (exp_values, exp_weights.sum(axis=-1), max_score) + + def chunk_scanner(chunk_idx): + # julienne key array + key_chunk = jax.lax.dynamic_slice( + operand=key, + start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] + slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] + ) + + # julienne value array + value_chunk = jax.lax.dynamic_slice( + operand=value, + start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] + slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] + ) + + return summarize_chunk(query, key_chunk, value_chunk) + + chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) + + global_max = jnp.max(chunk_max, axis=0, keepdims=True) + max_diffs = jnp.exp(chunk_max - global_max) + + chunk_values *= jnp.expand_dims(max_diffs, axis=-1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(axis=0) + all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) + + return all_values / all_weights + + +def jax_memory_efficient_attention( + query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 +): + r""" + Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2 + https://github.com/AminRezaei0x443/memory-efficient-attention + + Args: + query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) + key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) + value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) + precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): + numerical precision for computation + query_chunk_size (`int`, *optional*, defaults to 1024): + chunk size to divide query array value must divide query_length equally without remainder + key_chunk_size (`int`, *optional*, defaults to 4096): + chunk size to divide key and value array value must divide key_value_length equally without remainder + + Returns: + (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) + """ + num_q, num_heads, q_features = query.shape[-3:] + + def chunk_scanner(chunk_idx, _): + # julienne query array + query_chunk = jax.lax.dynamic_slice( + operand=query, + start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d] + slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d] + ) + + return ( + chunk_idx + query_chunk_size, # unused ignore it + _query_chunk_attention( + query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size + ), + ) + + _, res = jax.lax.scan( + f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter + ) + + return jnp.concatenate(res, axis=-3) # fuse the chunked result back + + +class FlaxAttention(nn.Module): + r""" + A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 + + Parameters: + query_dim (:obj:`int`): + Input hidden states dimension + heads (:obj:`int`, *optional*, defaults to 8): + Number of heads + dim_head (:obj:`int`, *optional*, defaults to 64): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + + """ + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + use_memory_efficient_attention: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim_head * self.heads + self.scale = self.dim_head**-0.5 + + # Weights were exported with old names {to_q, to_k, to_v, to_out} + self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") + self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") + self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") + + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def __call__(self, hidden_states, context=None, deterministic=True): + context = hidden_states if context is None else context + + query_proj = self.query(hidden_states) + key_proj = self.key(context) + value_proj = self.value(context) + + query_states = self.reshape_heads_to_batch_dim(query_proj) + key_states = self.reshape_heads_to_batch_dim(key_proj) + value_states = self.reshape_heads_to_batch_dim(value_proj) + + if self.use_memory_efficient_attention: + query_states = query_states.transpose(1, 0, 2) + key_states = key_states.transpose(1, 0, 2) + value_states = value_states.transpose(1, 0, 2) + + # this if statement create a chunk size for each layer of the unet + # the chunk size is equal to the query_length dimension of the deepest layer of the unet + + flatten_latent_dim = query_states.shape[-3] + if flatten_latent_dim % 64 == 0: + query_chunk_size = int(flatten_latent_dim / 64) + elif flatten_latent_dim % 16 == 0: + query_chunk_size = int(flatten_latent_dim / 16) + elif flatten_latent_dim % 4 == 0: + query_chunk_size = int(flatten_latent_dim / 4) + else: + query_chunk_size = int(flatten_latent_dim) + + hidden_states = jax_memory_efficient_attention( + query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 + ) + + hidden_states = hidden_states.transpose(1, 0, 2) + else: + # compute attentions + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) + attention_scores = attention_scores * self.scale + attention_probs = nn.softmax(attention_scores, axis=2) + + # attend to values + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) + + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.proj_attn(hidden_states) + return self.dropout_layer(hidden_states, deterministic=deterministic) + + +class FlaxBasicTransformerBlock(nn.Module): + r""" + A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: + https://arxiv.org/abs/1706.03762 + + + Parameters: + dim (:obj:`int`): + Inner hidden states dimension + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + only_cross_attention (`bool`, defaults to `False`): + Whether to only apply cross attention. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + """ + dim: int + n_heads: int + d_head: int + dropout: float = 0.0 + only_cross_attention: bool = False + dtype: jnp.dtype = jnp.float32 + use_memory_efficient_attention: bool = False + + def setup(self): + # self attention (or cross_attention if only_cross_attention is True) + self.attn1 = FlaxAttention( + self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype + ) + # cross attention + self.attn2 = FlaxAttention( + self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype + ) + self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) + self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def __call__(self, hidden_states, context, deterministic=True): + # self attention + residual = hidden_states + if self.only_cross_attention: + hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic) + else: + hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) + hidden_states = hidden_states + residual + + # cross attention + residual = hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) + hidden_states = hidden_states + residual + + # feed forward + residual = hidden_states + hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) + hidden_states = hidden_states + residual + + return self.dropout_layer(hidden_states, deterministic=deterministic) + + +class FlaxTransformer2DModel(nn.Module): + r""" + A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: + https://arxiv.org/pdf/1506.02025.pdf + + + Parameters: + in_channels (:obj:`int`): + Input number of channels + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + depth (:obj:`int`, *optional*, defaults to 1): + Number of transformers block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + use_linear_projection (`bool`, defaults to `False`): tbd + only_cross_attention (`bool`, defaults to `False`): tbd + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + """ + in_channels: int + n_heads: int + d_head: int + depth: int = 1 + dropout: float = 0.0 + use_linear_projection: bool = False + only_cross_attention: bool = False + dtype: jnp.dtype = jnp.float32 + use_memory_efficient_attention: bool = False + + def setup(self): + self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) + + inner_dim = self.n_heads * self.d_head + if self.use_linear_projection: + self.proj_in = nn.Dense(inner_dim, dtype=self.dtype) + else: + self.proj_in = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + self.transformer_blocks = [ + FlaxBasicTransformerBlock( + inner_dim, + self.n_heads, + self.d_head, + dropout=self.dropout, + only_cross_attention=self.only_cross_attention, + dtype=self.dtype, + use_memory_efficient_attention=self.use_memory_efficient_attention, + ) + for _ in range(self.depth) + ] + + if self.use_linear_projection: + self.proj_out = nn.Dense(inner_dim, dtype=self.dtype) + else: + self.proj_out = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def __call__(self, hidden_states, context, deterministic=True): + batch, height, width, channels = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + if self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height * width, channels) + hidden_states = self.proj_in(hidden_states) + else: + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.reshape(batch, height * width, channels) + + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) + + if self.use_linear_projection: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, channels) + else: + hidden_states = hidden_states.reshape(batch, height, width, channels) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states + residual + return self.dropout_layer(hidden_states, deterministic=deterministic) + + +class FlaxFeedForward(nn.Module): + r""" + Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's + [`FeedForward`] class, with the following simplifications: + - The activation function is currently hardcoded to a gated linear unit from: + https://arxiv.org/abs/2002.05202 + - `dim_out` is equal to `dim`. + - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`]. + + Parameters: + dim (:obj:`int`): + Inner hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # The second linear layer needs to be called + # net_2 for now to match the index of the Sequential layer + self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) + self.net_2 = nn.Dense(self.dim, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.net_0(hidden_states, deterministic=deterministic) + hidden_states = self.net_2(hidden_states) + return hidden_states + + +class FlaxGEGLU(nn.Module): + r""" + Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from + https://arxiv.org/abs/2002.05202. + + Parameters: + dim (:obj:`int`): + Input hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim * 4 + self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.proj(hidden_states) + hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) + return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic) diff --git a/diffusers/models/attention_processor.py b/diffusers/models/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..da2920fa671a31e489d3cc207d179e06015ae9f6 --- /dev/null +++ b/diffusers/models/attention_processor.py @@ -0,0 +1,1647 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import deprecate, logging, maybe_allow_in_graph +from ..utils.import_utils import is_xformers_available + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block=False, + processor: Optional["AttnProcessor"] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ): + is_lora = hasattr(self, "processor") and isinstance( + self.processor, + (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor), + ) + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor) + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and (is_lora or is_custom_diffusion): + raise NotImplementedError( + f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor, out_dim=3): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3): + if batch_size is None: + deprecate( + "batch_size=None", + "0.0.15", + ( + "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" + " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" + " `prepare_attention_mask` when preparing the attention_mask." + ), + ) + batch_size = 1 + + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states): + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRAAttnProcessor(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CustomDiffusionAttnProcessor(nn.Module): + r""" + Processor for implementing attention for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + """ + + def __init__( + self, + train_kv=True, + train_q_out=True, + hidden_size=None, + cross_attention_dim=None, + out_bias=True, + dropout=0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states) + else: + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states) + value = self.to_v_custom_diffusion(encoder_hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class AttnAddedKVProcessor: + r""" + Processor for performing attention-related computations with extra learnable key and value matrices for the text + encoder. + """ + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class AttnAddedKVProcessor2_0: + r""" + Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra + learnable key and value matrices for the text encoder. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query, out_dim=4) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key, out_dim=4) + value = attn.head_to_batch_dim(value, out_dim=4) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class LoRAAttnAddedKVProcessor(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text + encoder. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora( + encoder_hidden_states + ) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora( + encoder_hidden_states + ) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states) + value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class XFormersAttnAddedKVProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class XFormersAttnProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, key_tokens, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAXFormersAttnProcessor(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + + """ + + def __init__( + self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None + ): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.attention_op = attention_op + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAAttnProcessor2_0(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product + attention. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CustomDiffusionXFormersAttnProcessor(nn.Module): + r""" + Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use + as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. + """ + + def __init__( + self, + train_kv=True, + train_q_out=False, + hidden_size=None, + cross_attention_dim=None, + out_bias=True, + dropout=0.0, + attention_op: Optional[Callable] = None, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.attention_op = attention_op + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states) + else: + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states) + value = self.to_v_custom_diffusion(encoder_hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class SlicedAttnProcessor: + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class SlicedAttnAddedKVProcessor: + r""" + Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +AttentionProcessor = Union[ + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + LoRAAttnAddedKVProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, +] + + +class SpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 + """ + + def __init__( + self, + f_channels, + zq_channels, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f, zq): + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f diff --git a/diffusers/models/autoencoder_kl.py b/diffusers/models/autoencoder_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..2390d2bc58261c76a38cd18dc48dbd7fb59a4d58 --- /dev/null +++ b/diffusers/models/autoencoder_kl.py @@ -0,0 +1,417 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalVAEMixin +from ..utils import BaseOutput, apply_forward_hook +from .attention_processor import AttentionProcessor, AttnProcessor +from .modeling_utils import ModelMixin +from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" + + +class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + force_upcast: float = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + @apply_forward_hook + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a, b, blend_extent): + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a, b, blend_extent): + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/diffusers/models/controlnet.py b/diffusers/models/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..354cd5851d4c203428b34db20698b3bcb0e4a626 --- /dev/null +++ b/diffusers/models/controlnet.py @@ -0,0 +1,823 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalControlnetMixin +from ..utils import BaseOutput, logging +from .attention_processor import AttentionProcessor, AttnProcessor +from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from .unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads=64, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if "addition_embed_type" in self.config: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/diffusers/models/controlnet_composer.py b/diffusers/models/controlnet_composer.py new file mode 100644 index 0000000000000000000000000000000000000000..57ba3db9ede36664f2f366078276701529f3cda9 --- /dev/null +++ b/diffusers/models/controlnet_composer.py @@ -0,0 +1,886 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalControlnetMixin +from ..utils import BaseOutput, logging +from .attention_processor import AttentionProcessor, AttnProcessor +from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from .unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + cond_num: int = 3, + fusion: str = "sum", + normalize_to_0_1: bool = True, + ): + super().__init__() + + self.cond_num = cond_num + self.fusion = fusion + self.normalize_to_0_1 = normalize_to_0_1 + self.conv_in_list = nn.ModuleList([]) + for i in range(self.cond_num): + self.conv_in_list.append(nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)) + # self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks_list = nn.ModuleList([]) + for i in range(self.cond_num): + blocks = nn.ModuleList([]) + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + self.blocks_list.append(blocks) + + self.conv_out_list = nn.ModuleList([]) + for i in range(self.cond_num): + self.conv_out_list.append(zero_module(nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1))) + + if self.fusion == "sum": + pass + elif self.fusion == "avg": + pass + elif self.fusion == "learn": + self.fusion_conv = zero_module(nn.Conv2d(conditioning_embedding_channels * self.cond_num, conditioning_embedding_channels, kernel_size=3, padding=1)) + else: + assert False + + def forward(self, conditioning, gating_matrix=None): + assert len(conditioning) == self.cond_num + + if self.normalize_to_0_1: + conditioning = [x / 2. + 0.5 for x in conditioning] + + embedding_list = [] + for i in range(self.cond_num): + embedding = self.conv_in_list[i](conditioning[i]) + embedding = F.silu(embedding) + + for block in self.blocks_list[i]: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out_list[i](embedding) + embedding_list.append(embedding) + + if gating_matrix is None: + if self.fusion == "sum" or self.fusion == "avg": + stacked_tensor = torch.stack(embedding_list, dim=0) + embedding = torch.sum(stacked_tensor, dim=0) + if self.fusion == "avg": + embedding = embedding / self.cond_num + elif self.fusion == "learn": + concat_tensor = torch.cat(embedding_list, dim=1) + embedding = self.fusion_conv(concat_tensor) + else: + assert False + else: + # embedding shape is (B, 3, C, H, W) + # gating matrix shape is (B, 3, H, W) + # (B, 3, C, H, W) x (B, 3, H, W) -> (B, C, H, W) + embedding = torch.stack(embedding_list, dim=1) # (B, 3, C, H, W) + gating_matrix = gating_matrix.unsqueeze(2) # (B, 3, 1, H, W) + embedding = embedding * gating_matrix + embedding = embedding.sum(dim=1) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads=64, + cond_num: int = 3, + fusion: str = "sum", + normalize_to_0_1: bool = True, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + cond_num=cond_num, + fusion=fusion, + normalize_to_0_1=normalize_to_0_1, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + cond_num: int = 3, + fusion: str = "sum", + normalize_to_0_1: bool = True, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + cond_num=cond_num, + fusion=fusion, + normalize_to_0_1=normalize_to_0_1, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + # controlnet_cond: torch.FloatTensor, + controlnet_cond: List[torch.FloatTensor], + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + gating_matrix=None, + ) -> Union[ControlNetOutput, Tuple]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + for cond in controlnet_cond: + cond = torch.flip(cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if "addition_embed_type" in self.config: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + sample = self.conv_in(sample) + + cond_embedding = self.controlnet_cond_embedding(controlnet_cond, gating_matrix=gating_matrix) + + sample = sample + cond_embedding + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/diffusers/models/controlnet_flax.py b/diffusers/models/controlnet_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..a826df48e41a632454c513877ec55be7f86089f9 --- /dev/null +++ b/diffusers/models/controlnet_flax.py @@ -0,0 +1,394 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..utils import BaseOutput +from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .modeling_flax_utils import FlaxModelMixin +from .unet_2d_blocks_flax import ( + FlaxCrossAttnDownBlock2D, + FlaxDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, +) + + +@flax.struct.dataclass +class FlaxControlNetOutput(BaseOutput): + """ + The output of [`FlaxControlNetModel`]. + + Args: + down_block_res_samples (`jnp.ndarray`): + mid_block_res_sample (`jnp.ndarray`): + """ + + down_block_res_samples: jnp.ndarray + mid_block_res_sample: jnp.ndarray + + +class FlaxControlNetConditioningEmbedding(nn.Module): + conditioning_embedding_channels: int + block_out_channels: Tuple[int] = (16, 32, 96, 256) + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv_in = nn.Conv( + self.block_out_channels[0], + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + blocks = [] + for i in range(len(self.block_out_channels) - 1): + channel_in = self.block_out_channels[i] + channel_out = self.block_out_channels[i + 1] + conv1 = nn.Conv( + channel_in, + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + blocks.append(conv1) + conv2 = nn.Conv( + channel_out, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + blocks.append(conv2) + self.blocks = blocks + + self.conv_out = nn.Conv( + self.conditioning_embedding_channels, + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + + def __call__(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = nn.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = nn.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +@flax_register_to_config +class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + A ControlNet model. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods + implemented for all models (such as downloading or saving). + + This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its + general usage and behavior. + + Inherent JAX features such as the following are supported: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): + The tuple of downsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): + The dimension of the attention heads. + num_attention_heads (`int` or `Tuple[int]`, *optional*): + The number of attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + """ + sample_size: int = 32 + in_channels: int = 4 + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + only_cross_attention: Union[bool, Tuple[bool]] = False + block_out_channels: Tuple[int] = (320, 640, 1280, 1280) + layers_per_block: int = 2 + attention_head_dim: Union[int, Tuple[int]] = 8 + num_attention_heads: Optional[Union[int, Tuple[int]]] = None + cross_attention_dim: int = 1280 + dropout: float = 0.0 + use_linear_projection: bool = False + dtype: jnp.dtype = jnp.float32 + flip_sin_to_cos: bool = True + freq_shift: int = 0 + controlnet_conditioning_channel_order: str = "rgb" + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256) + + def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8) + controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"] + + def setup(self): + block_out_channels = self.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = self.num_attention_heads or self.attention_head_dim + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps( + block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift + ) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=self.conditioning_embedding_out_channels, + ) + + only_cross_attention = self.only_cross_attention + if isinstance(only_cross_attention, bool): + only_cross_attention = (only_cross_attention,) * len(self.down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(self.down_block_types) + + # down + down_blocks = [] + controlnet_down_blocks = [] + + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + num_attention_heads=num_attention_heads[i], + add_downsample=not is_final_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + + for _ in range(self.layers_per_block): + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + self.down_blocks = down_blocks + self.controlnet_down_blocks = controlnet_down_blocks + + # mid + mid_block_channel = block_out_channels[-1] + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=mid_block_channel, + dropout=self.dropout, + num_attention_heads=num_attention_heads[-1], + use_linear_projection=self.use_linear_projection, + dtype=self.dtype, + ) + + self.controlnet_mid_block = nn.Conv( + mid_block_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + + def __call__( + self, + sample, + timesteps, + encoder_hidden_states, + controlnet_cond, + conditioning_scale: float = 1.0, + return_dict: bool = True, + train: bool = False, + ) -> Union[FlaxControlNetOutput, Tuple]: + r""" + Args: + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states + controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor + conditioning_scale: (`float`) the scale factor for controlnet outputs + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + plain tuple. + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + channel_order = self.controlnet_conditioning_channel_order + if channel_order == "bgr": + controlnet_cond = jnp.flip(controlnet_cond, axis=1) + + # 1. time + if not isinstance(timesteps, jnp.ndarray): + timesteps = jnp.array([timesteps], dtype=jnp.int32) + elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: + timesteps = timesteps.astype(dtype=jnp.float32) + timesteps = jnp.expand_dims(timesteps, 0) + + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = jnp.transpose(sample, (0, 2, 3, 1)) + sample = self.conv_in(sample) + + controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1)) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample += controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + else: + sample, res_samples = down_block(sample, t_emb, deterministic=not train) + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + + # 5. contronet blocks + controlnet_down_block_res_samples = () + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return FlaxControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) diff --git a/diffusers/models/cross_attention.py b/diffusers/models/cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..44bc156b34cfa8536bdac0fee34709dfd66ae488 --- /dev/null +++ b/diffusers/models/cross_attention.py @@ -0,0 +1,94 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..utils import deprecate +from .attention_processor import ( # noqa: F401 + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor2_0, + LoRAAttnProcessor, + LoRALinearLayer, + LoRAXFormersAttnProcessor, + SlicedAttnAddedKVProcessor, + SlicedAttnProcessor, + XFormersAttnProcessor, +) +from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401 + + +deprecate( + "cross_attention", + "0.20.0", + "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.", + standard_warn=False, +) + + +AttnProcessor = AttentionProcessor + + +class CrossAttention(Attention): + def __init__(self, *args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) + + +class CrossAttnProcessor(AttnProcessorRename): + def __init__(self, *args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) + + +class LoRACrossAttnProcessor(LoRAAttnProcessor): + def __init__(self, *args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) + + +class CrossAttnAddedKVProcessor(AttnAddedKVProcessor): + def __init__(self, *args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) + + +class XFormersCrossAttnProcessor(XFormersAttnProcessor): + def __init__(self, *args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) + + +class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor): + def __init__(self, *args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) + + +class SlicedCrossAttnProcessor(SlicedAttnProcessor): + def __init__(self, *args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) + + +class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor): + def __init__(self, *args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) diff --git a/diffusers/models/dual_transformer_2d.py b/diffusers/models/dual_transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3db7e73ca6afc5fa7c67c1902d79e67c1aa728bc --- /dev/null +++ b/diffusers/models/dual_transformer_2d.py @@ -0,0 +1,151 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from torch import nn + +from .transformer_2d import Transformer2DModel, Transformer2DModelOutput + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward( + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in Attention + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) diff --git a/diffusers/models/embeddings.py b/diffusers/models/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..a5a0c5549ee9d282b4eaa41d496255ad26b74699 --- /dev/null +++ b/diffusers/models/embeddings.py @@ -0,0 +1,546 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional + +import numpy as np +import torch +from torch import nn + +from .activations import get_activation + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + return latent + self.pos_embed + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + self.weight = self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.LongTensor, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class TextImageProjection(nn.Module): + def __init__( + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) + + def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): + batch_size = text_embeds.shape[0] + + # image + image_text_embeds = self.image_embeds(image_embeds) + image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + + # text + text_embeds = self.text_proj(text_embeds) + + return torch.cat([image_text_embeds, text_embeds], dim=1) + + +class ImageProjection(nn.Module): + def __init__( + self, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 32, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.FloatTensor): + batch_size = image_embeds.shape[0] + + # image + image_embeds = self.image_embeds(image_embeds) + image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + image_embeds = self.norm(image_embeds) + return image_embeds + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + class_labels = self.class_embedder(class_labels) # (N, D) + + conditioning = timesteps_emb + class_labels # (N, D) + + return conditioning + + +class TextTimeEmbedding(nn.Module): + def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): + super().__init__() + self.norm1 = nn.LayerNorm(encoder_dim) + self.pool = AttentionPooling(num_heads, encoder_dim) + self.proj = nn.Linear(encoder_dim, time_embed_dim) + self.norm2 = nn.LayerNorm(time_embed_dim) + + def forward(self, hidden_states): + hidden_states = self.norm1(hidden_states) + hidden_states = self.pool(hidden_states) + hidden_states = self.proj(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class TextImageTimeEmbedding(nn.Module): + def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) + self.text_norm = nn.LayerNorm(time_embed_dim) + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + + def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): + # text + time_text_embeds = self.text_proj(text_embeds) + time_text_embeds = self.text_norm(time_text_embeds) + + # image + time_image_embeds = self.image_proj(image_embeds) + + return time_image_embeds + time_text_embeds + + +class ImageTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + + def forward(self, image_embeds: torch.FloatTensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + return time_image_embeds + + +class ImageHintTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(256, 4, 3, padding=1), + ) + + def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + hint = self.input_hint_block(hint) + return time_image_embeds, hint + + +class AttentionPooling(nn.Module): + # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 + + def __init__(self, num_heads, embed_dim, dtype=None): + super().__init__() + self.dtype = dtype + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.num_heads = num_heads + self.dim_per_head = embed_dim // self.num_heads + + def forward(self, x): + bs, length, width = x.size() + + def shape(x): + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, -1, self.num_heads, self.dim_per_head) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) + # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) + x = x.transpose(1, 2) + return x + + class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) + x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) + + # (bs*n_heads, class_token_length, dim_per_head) + q = shape(self.q_proj(class_token)) + # (bs*n_heads, length+class_token_length, dim_per_head) + k = shape(self.k_proj(x)) + v = shape(self.v_proj(x)) + + # (bs*n_heads, class_token_length, length+class_token_length): + scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # (bs*n_heads, dim_per_head, class_token_length) + a = torch.einsum("bts,bcs->bct", weight, v) + + # (bs, length+1, width) + a = a.reshape(bs, -1, 1).transpose(1, 2) + + return a[:, 0, :] # cls_token diff --git a/diffusers/models/embeddings_flax.py b/diffusers/models/embeddings_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..88c2c45e4655b8013fa96e0b4408e3ec0a87c2c7 --- /dev/null +++ b/diffusers/models/embeddings_flax.py @@ -0,0 +1,95 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import flax.linen as nn +import jax.numpy as jnp + + +def get_sinusoidal_embeddings( + timesteps: jnp.ndarray, + embedding_dim: int, + freq_shift: float = 1, + min_timescale: float = 1, + max_timescale: float = 1.0e4, + flip_sin_to_cos: bool = False, + scale: float = 1.0, +) -> jnp.ndarray: + """Returns the positional encoding (same as Tensor2Tensor). + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + embedding_dim: The number of output channels. + min_timescale: The smallest time unit (should probably be 0.0). + max_timescale: The largest time unit. + Returns: + a Tensor of timing signals [N, num_channels] + """ + assert timesteps.ndim == 1, "Timesteps should be a 1d-array" + assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" + num_timescales = float(embedding_dim // 2) + log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) + inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) + emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) + + # scale embeddings + scaled_time = scale * emb + + if flip_sin_to_cos: + signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) + else: + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) + signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) + return signal + + +class FlaxTimestepEmbedding(nn.Module): + r""" + Time step Embedding Module. Learns embeddings for input time steps. + + Args: + time_embed_dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + time_embed_dim: int = 32 + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, temb): + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) + temb = nn.silu(temb) + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) + return temb + + +class FlaxTimesteps(nn.Module): + r""" + Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 + + Args: + dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + """ + dim: int = 32 + flip_sin_to_cos: bool = False + freq_shift: float = 1 + + @nn.compact + def __call__(self, timesteps): + return get_sinusoidal_embeddings( + timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift + ) diff --git a/diffusers/models/modeling_flax_pytorch_utils.py b/diffusers/models/modeling_flax_pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9de83f87dab84d2e7fdd77b835db787cb4f1cb6 --- /dev/null +++ b/diffusers/models/modeling_flax_pytorch_utils.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch - Flax general utilities.""" +import re + +import jax.numpy as jnp +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +def rename_key(key): + regex = r"\w+[.]\d+" + pats = re.findall(regex, key) + for pat in pats: + key = key.replace(pat, "_".join(pat.split("."))) + return key + + +##################### +# PyTorch => Flax # +##################### + + +# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 +# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py +def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): + """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" + + # conv norm or layer norm + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + if ( + any("norm" in str_ for str_ in pt_tuple_key) + and (pt_tuple_key[-1] == "bias") + and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) + and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) + ): + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + return renamed_pt_tuple_key, pt_tensor + elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + return renamed_pt_tuple_key, pt_tensor + + # embedding + if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: + pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + return renamed_pt_tuple_key, pt_tensor + + # conv layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: + pt_tensor = pt_tensor.transpose(2, 3, 1, 0) + return renamed_pt_tuple_key, pt_tensor + + # linear layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight": + pt_tensor = pt_tensor.T + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm weight + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) + if pt_tuple_key[-1] == "gamma": + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm bias + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) + if pt_tuple_key[-1] == "beta": + return renamed_pt_tuple_key, pt_tensor + + return pt_tuple_key, pt_tensor + + +def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): + # Step 1: Convert pytorch tensor to numpy + pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + + # Step 2: Since the model is stateless, get random Flax params + random_flax_params = flax_model.init_weights(PRNGKey(init_key)) + + random_flax_state_dict = flatten_dict(random_flax_params) + flax_state_dict = {} + + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + renamed_pt_key = rename_key(pt_key) + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + + return unflatten_dict(flax_state_dict) diff --git a/diffusers/models/modeling_flax_utils.py b/diffusers/models/modeling_flax_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6e1b3bba3d94e0252794cd0eda079f2c6f4183 --- /dev/null +++ b/diffusers/models/modeling_flax_utils.py @@ -0,0 +1,534 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pickle import UnpicklingError +from typing import Any, Dict, Union + +import jax +import jax.numpy as jnp +import msgpack.exceptions +from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.serialization import from_bytes, to_bytes +from flax.traverse_util import flatten_dict, unflatten_dict +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from .. import __version__, is_torch_available +from ..utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + FLAX_WEIGHTS_NAME, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_NAME, + logging, +) +from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax + + +logger = logging.get_logger(__name__) + + +class FlaxModelMixin: + r""" + Base class for all Flax models. + + [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and + saving models. + + - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`]. + """ + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _flax_internal_args = ["name", "parent", "dtype"] + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_map(conditional_cast, params) + + flat_params = flatten_dict(params) + flat_mask, _ = jax.tree_flatten(mask) + + for masked, key in zip(flat_mask, flat_params.keys()): + if masked: + param = flat_params[key] + flat_params[key] = conditional_cast(param) + + return unflatten_dict(flat_params) + + def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast + the `params` in place. + + This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full + half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # load model + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision + >>> params = model.to_bf16(params) + >>> # If you don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> flat_params = traverse_util.flatten_dict(params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> params = model.to_bf16(params, mask) + ```""" + return self._cast_floating_to(params, jnp.bfloat16, mask) + + def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the + model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # Download model and configuration from huggingface.co + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> # By default, the model params will be in fp32, to illustrate the use of this method, + >>> # we'll first cast to fp16 and back to fp32 + >>> params = model.to_f16(params) + >>> # now cast back to fp32 + >>> params = model.to_fp32(params) + ```""" + return self._cast_floating_to(params, jnp.float32, mask) + + def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + `params` in place. + + This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full + half-precision training or to save weights in float16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # load model + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> # By default, the model params will be in fp32, to cast these to float16 + >>> params = model.to_fp16(params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> flat_params = traverse_util.flatten_dict(params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> params = model.to_fp16(params, mask) + ```""" + return self._cast_floating_to(params, jnp.float16, mask) + + def init_weights(self, rng: jax.random.KeyArray) -> Dict: + raise NotImplementedError(f"init_weights method has to be implemented for {self}") + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + dtype: jnp.dtype = jnp.float32, + *model_args, + **kwargs, + ): + r""" + Instantiate a pretrained Flax model from a pretrained model configuration. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model + hosted on the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + using [`~FlaxModelMixin.save_pretrained`]. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified, all the computation will be performed with the given `dtype`. + + + + This only specifies the dtype of the *computation* and does not influence the dtype of model + parameters. + + If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and + [`~FlaxModelMixin.to_bf16`]. + + + + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments are passed to the underlying model's `__init__` method. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it is loaded) and initiate the model (for + example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying + model's `__init__` method (we assume all relevant updates to the configuration have already been + done). + - If a configuration is not provided, `kwargs` are first passed to the configuration class + initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds + to a configuration attribute is used to override said attribute with the supplied `kwargs` value. + Remaining keys that do not correspond to any configuration attribute are passed to the underlying + model's `__init__` function. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + config = kwargs.pop("config", None) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + from_pt = kwargs.pop("from_pt", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "flax", + } + + # Load config if we don't provide a configuration + config_path = config if config is not None else pretrained_model_name_or_path + model, model_kwargs = cls.from_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + # model args + dtype=dtype, + **kwargs, + ) + + # Load model + pretrained_path_with_subfolder = ( + pretrained_model_name_or_path + if subfolder is None + else os.path.join(pretrained_model_name_or_path, subfolder) + ) + if os.path.isdir(pretrained_path_with_subfolder): + if from_pt: + if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): + raise EnvironmentError( + f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} " + ) + model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME) + # Check if pytorch weights exist instead + elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): + raise EnvironmentError( + f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model" + " using `from_pt=True`." + ) + else: + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_path_with_subfolder}." + ) + else: + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" + f"{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your" + " internet connection or see how to run the library in offline mode at" + " 'https://huggingface.co/docs/transformers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + + if from_pt: + if is_torch_available(): + from .modeling_utils import load_state_dict + else: + raise EnvironmentError( + "Can't load the model in PyTorch format because PyTorch is not installed. " + "Please, install PyTorch or use native Flax weights." + ) + + # Step 1: Get the pytorch file + pytorch_model_file = load_state_dict(model_file) + + # Step 2: Convert the weights + state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) + else: + try: + with open(model_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + # make sure all arrays are stored as jnp.ndarray + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 + state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) + + # flatten dicts + state = flatten_dict(state) + + params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) + required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) + + shape_state = flatten_dict(unfreeze(params_shape_tree)) + + missing_keys = required_params - set(state.keys()) + unexpected_keys = set(state.keys()) - required_params + + if missing_keys: + logger.warning( + f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " + "Make sure to call model.init_weights to initialize the missing weights." + ) + cls._missing_keys = missing_keys + + for key in state.keys(): + if key in shape_state and state[key].shape != shape_state[key].shape: + raise ValueError( + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " + f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " + ) + + # remove unexpected keys to not be saved again + for unexpected_key in unexpected_keys: + del state[unexpected_key] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + + return model, unflatten_dict(state) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params: Union[Dict, FrozenDict], + is_main_process: bool = True, + ): + """ + Save a model and its configuration file to a directory so that it can be reloaded using the + [`~FlaxModelMixin.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a model and its configuration file to. Will be created if it doesn't exist. + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # save model + output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) + with open(output_model_file, "wb") as f: + model_bytes = to_bytes(params) + f.write(model_bytes) + + logger.info(f"Model weights saved in {output_model_file}") diff --git a/diffusers/models/modeling_pytorch_flax_utils.py b/diffusers/models/modeling_pytorch_flax_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a61638ad02f7a38a1439f35dea5966c7c7d519d8 --- /dev/null +++ b/diffusers/models/modeling_pytorch_flax_utils.py @@ -0,0 +1,161 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch - Flax general utilities.""" + +from pickle import UnpicklingError + +import jax +import jax.numpy as jnp +import numpy as np +from flax.serialization import from_bytes +from flax.traverse_util import flatten_dict + +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +##################### +# Flax => PyTorch # +##################### + + +# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352 +def load_flax_checkpoint_in_pytorch_model(pt_model, model_file): + try: + with open(model_file, "rb") as flax_state_f: + flax_state = from_bytes(None, flax_state_f.read()) + except UnpicklingError as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + + return load_flax_weights_in_pytorch_model(pt_model, flax_state) + + +def load_flax_weights_in_pytorch_model(pt_model, flax_state): + """Load flax checkpoints in a PyTorch model""" + + try: + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" + " instructions." + ) + raise + + # check if we have bf16 weights + is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() + if any(is_type_bf16): + # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16 + + # and bf16 is not fully supported in PT yet. + logger.warning( + "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " + "before loading those in PyTorch model." + ) + flax_state = jax.tree_util.tree_map( + lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state + ) + + pt_model.base_model_prefix = "" + + flax_state_dict = flatten_dict(flax_state, sep=".") + pt_model_dict = pt_model.state_dict() + + # keep track of unexpected & missing keys + unexpected_keys = [] + missing_keys = set(pt_model_dict.keys()) + + for flax_key_tuple, flax_tensor in flax_state_dict.items(): + flax_key_tuple_array = flax_key_tuple.split(".") + + if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4: + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) + elif flax_key_tuple_array[-1] == "kernel": + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + flax_tensor = flax_tensor.T + elif flax_key_tuple_array[-1] == "scale": + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + + if "time_embedding" not in flax_key_tuple_array: + for i, flax_key_tuple_string in enumerate(flax_key_tuple_array): + flax_key_tuple_array[i] = ( + flax_key_tuple_string.replace("_0", ".0") + .replace("_1", ".1") + .replace("_2", ".2") + .replace("_3", ".3") + .replace("_4", ".4") + .replace("_5", ".5") + .replace("_6", ".6") + .replace("_7", ".7") + .replace("_8", ".8") + .replace("_9", ".9") + ) + + flax_key = ".".join(flax_key_tuple_array) + + if flax_key in pt_model_dict: + if flax_tensor.shape != pt_model_dict[flax_key].shape: + raise ValueError( + f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " + f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + else: + # add weight to pytorch dict + flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor + pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) + # remove from missing keys + missing_keys.remove(flax_key) + else: + # weight is not expected by PyTorch model + unexpected_keys.append(flax_key) + + pt_model.load_state_dict(pt_model_dict) + + # re-transform missing_keys to list + missing_keys = list(missing_keys) + + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the Flax model were not used when initializing the PyTorch model" + f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" + " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" + f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" + " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" + " FlaxBertForSequenceClassification model)." + ) + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" + f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" + " use it for predictions and inference." + ) + + return pt_model diff --git a/diffusers/models/modeling_utils.py b/diffusers/models/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa96514c0a9e39b9321550f6d85a8e11b0deb36 --- /dev/null +++ b/diffusers/models/modeling_utils.py @@ -0,0 +1,980 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import itertools +import os +import re +from functools import partial +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor, device, nn + +from .. import __version__ +from ..utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + FLAX_WEIGHTS_NAME, + HF_HUB_OFFLINE, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + deprecate, + is_accelerate_available, + is_safetensors_available, + is_torch_version, + logging, +) + + +logger = logging.get_logger(__name__) + + +if is_torch_version(">=", "1.9.0"): + _LOW_CPU_MEM_USAGE_DEFAULT = True +else: + _LOW_CPU_MEM_USAGE_DEFAULT = False + + +if is_accelerate_available(): + import accelerate + from accelerate.utils import set_module_tensor_to_device + from accelerate.utils.versions import is_torch_version + +if is_safetensors_available(): + import safetensors + + +def get_parameter_device(parameter: torch.nn.Module): + try: + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + params = tuple(parameter.parameters()) + if len(params) > 0: + return params[0].dtype + + buffers = tuple(parameter.buffers()) + if len(buffers) > 0: + return buffers[0].dtype + + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): + """ + Reads a checkpoint file, returning properly formatted errors if they arise. + """ + try: + if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant): + return torch.load(checkpoint_file, map_location="cpu") + else: + return safetensors.torch.load_file(checkpoint_file, device="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def _load_state_dict_into_model(model_to_load, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs + + +class ModelMixin(torch.nn.Module): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and + saving models. + + - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. + """ + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None + + def __init__(self): + super().__init__() + + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite + __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__': + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) + return self._internal_dict[name] + + # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + return super().__getattr__(name) + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + """ + Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). + """ + if not self._supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self): + """ + Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + r""" + Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up during + inference. Speed up during training is not guaranteed. + + + + ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes + precedent. + + + + Parameters: + attention_op (`Callable`, *optional*): + Override the default `None` operator for use as `op` argument to the + [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) + function of xFormers. + + Examples: + + ```py + >>> import torch + >>> from diffusers import UNet2DConditionModel + >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + >>> model = UNet2DConditionModel.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16 + ... ) + >>> model = model.to("cuda") + >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + """ + self.set_use_memory_efficient_attention_xformers(False) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory so that it can be reloaded using the + [`~models.ModelMixin.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a model and its configuration file to. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + """ + if safe_serialization and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + state_dict = model_to_save.state_dict() + + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + + # Save the model + if safe_serialization: + safetensors.torch.save_file( + state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"} + ) + else: + torch.save(state_dict, os.path.join(save_directory, weights_name)) + + logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + user_agent=user_agent, + **kwargs, + ) + + # load model + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + unexpected_keys = [] + + empty_state_dict = model.state_dict() + for param_name, param in state_dict.items(): + accepts_dtype = "dtype" in set( + inspect.signature(set_module_tensor_to_device).parameters.keys() + ) + + if param_name not in empty_state_dict: + unexpected_keys.append(param_name) + continue + + if empty_state_dict[param_name].shape != param.shape: + raise ValueError( + f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + ) + + if accepts_dtype: + set_module_tensor_to_device( + model, param_name, param_device, value=param, dtype=torch_dtype + ) + else: + set_module_tensor_to_device(model, param_name, param_device, value=param) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warn( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by default the device_map is None and the weights are loaded on the CPU + try: + accelerate.load_checkpoint_and_dispatch( + model, + model_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + ) + except AttributeError as e: + # When using accelerate loading, we do not have the ability to load the state + # dict and rename the weight names manually. Additionally, accelerate skips + # torch loading conventions and directly writes into `module.{_buffers, _parameters}` + # (which look like they should be private variables?), so we can't use the standard hooks + # to rename parameters on load. We need to mimic the original weight names so the correct + # attributes are available. After we have loaded the weights, we convert the deprecated + # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert + # the weights so we don't have to do this again. + + if "'Attention' object has no attribute" in str(e): + logger.warn( + f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" + " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" + " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," + " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," + " please also re-upload it or open a PR on the original repository." + ) + model._temp_convert_self_to_deprecated_attention_blocks() + accelerate.load_checkpoint_and_dispatch( + model, + model_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + ) + model._undo_temp_convert_self_to_deprecated_attention_blocks() + else: + raise e + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (trainable or non-embedding) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters. + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embedding parameters. + + Returns: + `int`: The number of parameters. + + Example: + + ```py + from diffusers import UNet2DConditionModel + + model_id = "runwayml/stable-diffusion-v1-5" + unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet") + unet.num_parameters(only_trainable=True) + 859520964 + ``` + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + def _convert_deprecated_attention_blocks(self, state_dict): + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + # NOTE: we have to check if the deprecated parameters are in the state dict + # because it is possible we are loading from a state dict that was already + # converted + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") + + def _temp_convert_self_to_deprecated_attention_blocks(self): + deprecated_attention_block_modules = [] + + def recursive_find_attn_block(module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_modules.append(module) + + for sub_module in module.children(): + recursive_find_attn_block(sub_module) + + recursive_find_attn_block(self) + + for module in deprecated_attention_block_modules: + module.query = module.to_q + module.key = module.to_k + module.value = module.to_v + module.proj_attn = module.to_out[0] + + # We don't _have_ to delete the old attributes, but it's helpful to ensure + # that _all_ the weights are loaded into the new attributes and we're not + # making an incorrect assumption that this model should be converted when + # it really shouldn't be. + del module.to_q + del module.to_k + del module.to_v + del module.to_out + + def _undo_temp_convert_self_to_deprecated_attention_blocks(self): + deprecated_attention_block_modules = [] + + def recursive_find_attn_block(module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_modules.append(module) + + for sub_module in module.children(): + recursive_find_attn_block(sub_module) + + recursive_find_attn_block(self) + + for module in deprecated_attention_block_modules: + module.to_q = module.query + module.to_k = module.key + module.to_v = module.value + module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)]) + + del module.query + del module.key + del module.value + del module.proj_attn diff --git a/diffusers/models/prior_transformer.py b/diffusers/models/prior_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3c61dd7561742114947e3419c19fec8c2a824f --- /dev/null +++ b/diffusers/models/prior_transformer.py @@ -0,0 +1,364 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .attention import BasicTransformerBlock +from .attention_processor import AttentionProcessor, AttnProcessor +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin + + +@dataclass +class PriorTransformerOutput(BaseOutput): + """ + The output of [`PriorTransformer`]. + + Args: + predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + The predicted CLIP image embedding conditioned on the CLIP text embedding input. + """ + + predicted_image_embedding: torch.FloatTensor + + +class PriorTransformer(ModelMixin, ConfigMixin): + """ + A Prior Transformer model. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` + num_embeddings (`int`, *optional*, defaults to 77): + The number of embeddings of the model input `hidden_states` + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + + additional_embeddings`. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + time_embed_act_fn (`str`, *optional*, defaults to 'silu'): + The activation function to use to create timestep embeddings. + norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before + passing to Transformer blocks. Set it to `None` if normalization is not needed. + embedding_proj_norm_type (`str`, *optional*, defaults to None): + The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not + needed. + encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): + The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if + `encoder_hidden_states` is `None`. + added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. + Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot + product between the text embedding and image embedding as proposed in the unclip paper + https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended. + time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings. + If None, will be set to `num_attention_heads * attention_head_dim` + embedding_proj_dim (`int`, *optional*, default to None): + The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. + clip_embed_dim (`int`, *optional*, default to None): + The dimension of the output. If None, will be set to `embedding_dim`. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 32, + attention_head_dim: int = 64, + num_layers: int = 20, + embedding_dim: int = 768, + num_embeddings=77, + additional_embeddings=4, + dropout: float = 0.0, + time_embed_act_fn: str = "silu", + norm_in_type: Optional[str] = None, # layer + embedding_proj_norm_type: Optional[str] = None, # layer + encoder_hid_proj_type: Optional[str] = "linear", # linear + added_emb_type: Optional[str] = "prd", # prd + time_embed_dim: Optional[int] = None, + embedding_proj_dim: Optional[int] = None, + clip_embed_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.additional_embeddings = additional_embeddings + + time_embed_dim = time_embed_dim or inner_dim + embedding_proj_dim = embedding_proj_dim or embedding_dim + clip_embed_dim = clip_embed_dim or embedding_dim + + self.time_proj = Timesteps(inner_dim, True, 0) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) + + self.proj_in = nn.Linear(embedding_dim, inner_dim) + + if embedding_proj_norm_type is None: + self.embedding_proj_norm = None + elif embedding_proj_norm_type == "layer": + self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) + else: + raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") + + self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) + + if encoder_hid_proj_type is None: + self.encoder_hidden_states_proj = None + elif encoder_hid_proj_type == "linear": + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + else: + raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") + + self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) + + if added_emb_type == "prd": + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + elif added_emb_type is None: + self.prd_embedding = None + else: + raise ValueError( + f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn="gelu", + attention_bias=True, + ) + for d in range(num_layers) + ] + ) + + if norm_in_type == "layer": + self.norm_in = nn.LayerNorm(inner_dim) + elif norm_in_type is None: + self.norm_in = None + else: + raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") + + self.norm_out = nn.LayerNorm(inner_dim) + + self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) + + causal_attention_mask = torch.full( + [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 + ) + causal_attention_mask.triu_(1) + causal_attention_mask = causal_attention_mask[None, ...] + self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) + + self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def forward( + self, + hidden_states, + timestep: Union[torch.Tensor, float, int], + proj_embedding: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + return_dict: bool = True, + ): + """ + The [`PriorTransformer`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + The currently predicted image embeddings. + timestep (`torch.LongTensor`): + Current denoising step. + proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + Projected embedding vector the denoising process is conditioned on. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): + Hidden states of the text embeddings the denoising process is conditioned on. + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): + Text mask for the text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain + tuple. + + Returns: + [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: + If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + batch_size = hidden_states.shape[0] + + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) + + timesteps_projected = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timesteps_projected = timesteps_projected.to(dtype=self.dtype) + time_embeddings = self.time_embedding(timesteps_projected) + + if self.embedding_proj_norm is not None: + proj_embedding = self.embedding_proj_norm(proj_embedding) + + proj_embeddings = self.embedding_proj(proj_embedding) + if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") + + hidden_states = self.proj_in(hidden_states) + + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) + + additional_embeds = [] + additional_embeddings_len = 0 + + if encoder_hidden_states is not None: + additional_embeds.append(encoder_hidden_states) + additional_embeddings_len += encoder_hidden_states.shape[1] + + if len(proj_embeddings.shape) == 2: + proj_embeddings = proj_embeddings[:, None, :] + + if len(hidden_states.shape) == 2: + hidden_states = hidden_states[:, None, :] + + additional_embeds = additional_embeds + [ + proj_embeddings, + time_embeddings[:, None, :], + hidden_states, + ] + + if self.prd_embedding is not None: + prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + additional_embeds.append(prd_embedding) + + hidden_states = torch.cat( + additional_embeds, + dim=1, + ) + + # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens + additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 + if positional_embeddings.shape[1] < hidden_states.shape[1]: + positional_embeddings = F.pad( + positional_embeddings, + ( + 0, + 0, + additional_embeddings_len, + self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, + ), + value=0.0, + ) + + hidden_states = hidden_states + positional_embeddings + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) + attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) + attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) + + if self.norm_in is not None: + hidden_states = self.norm_in(hidden_states) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + + hidden_states = self.norm_out(hidden_states) + + if self.prd_embedding is not None: + hidden_states = hidden_states[:, -1] + else: + hidden_states = hidden_states[:, additional_embeddings_len:] + + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) + + if not return_dict: + return (predicted_image_embedding,) + + return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) + + def post_process_latents(self, prior_latents): + prior_latents = (prior_latents * self.clip_std) + self.clip_mean + return prior_latents diff --git a/diffusers/models/resnet.py b/diffusers/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..24c3b07e7cb65447ad996b00066d42a74700dd97 --- /dev/null +++ b/diffusers/models/resnet.py @@ -0,0 +1,877 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .activations import get_activation +from .attention import AdaGroupNorm +from .attention_processor import SpatialNorm + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class Downsample1D(nn.Module): + """A 1D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + return self.conv(inputs) + + +class Upsample2D(nn.Module): + """A 2D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample2D(nn.Module): + """A 2D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class FirUpsample2D(nn.Module): + """A 2D FIR upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `Conv2d()`. + + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Setup filter kernel. + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convH = weight.shape[2] + convW = weight.shape[3] + inC = weight.shape[1] + + pad_value = (kernel.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + # Determine data dimensions. + output_shape = ( + (hidden_states.shape[2] - 1) * factor + convH, + (hidden_states.shape[3] - 1) * factor + convW, + ) + output_padding = ( + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = hidden_states.shape[1] // inC + + # Transpose weights. + weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) + + inverse_conv = F.conv_transpose2d( + hidden_states, weight, stride=stride, output_padding=output_padding, padding=0 + ) + + output = upfirdn2d_native( + inverse_conv, + torch.tensor(kernel, device=inverse_conv.device), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), + ) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + + return output + + def forward(self, hidden_states): + if self.use_conv: + height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return height + + +class FirDownsample2D(nn.Module): + """A 2D FIR downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """Fused `Conv2d()` followed by `downsample_2d()`. + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight: + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * + factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and + same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + + if self.use_conv: + _, _, convH, convW = weight.shape + pad_value = (kernel.shape[0] - factor) + (convW - 1) + stride_value = [factor, factor] + upfirdn_input = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + pad=((pad_value + 1) // 2, pad_value // 2), + ) + output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + + return output + + def forward(self, hidden_states): + if self.use_conv: + downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return hidden_states + + +# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead +class KDownsample2D(nn.Module): + def __init__(self, pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, inputs): + inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) + weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + indices = torch.arange(inputs.shape[1], device=inputs.device) + kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) + weight[indices, indices] = kernel + return F.conv2d(inputs, weight, stride=2) + + +class KUpsample2D(nn.Module): + def __init__(self, pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) + + def forward(self, inputs): + inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) + weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + indices = torch.arange(inputs.shape[1], device=inputs.device) + kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) + weight[indices, indices] = kernel + return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) + + +class ResnetBlock2D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or + "ada_group" for a stronger conditioning with scale and shift. + kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ + + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + skip_time_act=False, + time_embedding_norm="default", # default, scale_shift, ada_group, spatial + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act + + if groups_out is None: + groups_out = groups + + if self.time_embedding_norm == "ada_group": + self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm1 = SpatialNorm(in_channels, temb_channels) + else: + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) + elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + self.time_emb_proj = None + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + else: + self.time_emb_proj = None + + if self.time_embedding_norm == "ada_group": + self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm2 = SpatialNorm(out_channels, temb_channels) + else: + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + hidden_states = self.norm1(hidden_states, temb) + else: + hidden_states = self.norm1(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": + hidden_states = self.norm2(hidden_states, temb) + else: + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +# unet_rl.py +def rearrange_dims(tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = nn.Mish() + + def forward(self, inputs): + intermediate_repr = self.conv1d(inputs) + intermediate_repr = rearrange_dims(intermediate_repr) + intermediate_repr = self.group_norm(intermediate_repr) + intermediate_repr = rearrange_dims(intermediate_repr) + output = self.mish(intermediate_repr) + return output + + +# unet_rl.py +class ResidualTemporalBlock1D(nn.Module): + def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) + + self.time_emb_act = nn.Mish() + self.time_emb = nn.Linear(embed_dim, out_channels) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() + ) + + def forward(self, inputs, t): + """ + Args: + inputs : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.conv_in(inputs) + rearrange_dims(t) + out = self.conv_out(out) + return out + self.residual_conv(inputs) + + +def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): + r"""Upsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output + + +def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): + r"""Downsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) + ) + return output + + +def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = tensor.shape + kernel_h, kernel_w = kernel.shape + + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(tensor.device) # Move back to mps if necessary + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) + + +class TemporalConvLayer(nn.Module): + """ + Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__(self, in_dim, out_dim=None, dropout=0.0): + super().__init__() + out_dim = out_dim or in_dim + self.in_dim = in_dim + self.out_dim = out_dim + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, hidden_states, num_frames=1): + hidden_states = ( + hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) + ) + + identity = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.conv3(hidden_states) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape( + (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] + ) + return hidden_states diff --git a/diffusers/models/resnet_flax.py b/diffusers/models/resnet_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..9a391f4b947e74beda03f26e376141b2b3c21502 --- /dev/null +++ b/diffusers/models/resnet_flax.py @@ -0,0 +1,124 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flax.linen as nn +import jax +import jax.numpy as jnp + + +class FlaxUpsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), # padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + # hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + in_channels: int + out_channels: int = None + dropout_prob: float = 0.0 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) + + self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.dropout = nn.Dropout(self.dropout_prob) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, temb, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(nn.swish(temb)) + temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual diff --git a/diffusers/models/t5_film_transformer.py b/diffusers/models/t5_film_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1c41e656a9dbe81edafd5a2958d49ff28e84fd01 --- /dev/null +++ b/diffusers/models/t5_film_transformer.py @@ -0,0 +1,321 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from .attention_processor import Attention +from .embeddings import get_timestep_embedding +from .modeling_utils import ModelMixin + + +class T5FilmDecoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + input_dims: int = 128, + targets_length: int = 256, + max_decoder_noise_time: float = 2000.0, + d_model: int = 768, + num_layers: int = 12, + num_heads: int = 12, + d_kv: int = 64, + d_ff: int = 2048, + dropout_rate: float = 0.1, + ): + super().__init__() + + self.conditioning_emb = nn.Sequential( + nn.Linear(d_model, d_model * 4, bias=False), + nn.SiLU(), + nn.Linear(d_model * 4, d_model * 4, bias=False), + nn.SiLU(), + ) + + self.position_encoding = nn.Embedding(targets_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) + + self.dropout = nn.Dropout(p=dropout_rate) + + self.decoders = nn.ModuleList() + for lyr_num in range(num_layers): + # FiLM conditional T5 decoder + lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) + self.decoders.append(lyr) + + self.decoder_norm = T5LayerNorm(d_model) + + self.post_dropout = nn.Dropout(p=dropout_rate) + self.spec_out = nn.Linear(d_model, input_dims, bias=False) + + def encoder_decoder_mask(self, query_input, key_input): + mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) + return mask.unsqueeze(-3) + + def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): + batch, _, _ = decoder_input_tokens.shape + assert decoder_noise_time.shape == (batch,) + + # decoder_noise_time is in [0, 1), so rescale to expected timing range. + time_steps = get_timestep_embedding( + decoder_noise_time * self.config.max_decoder_noise_time, + embedding_dim=self.config.d_model, + max_period=self.config.max_decoder_noise_time, + ).to(dtype=self.dtype) + + conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) + + assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) + + seq_length = decoder_input_tokens.shape[1] + + # If we want to use relative positions for audio context, we can just offset + # this sequence by the length of encodings_and_masks. + decoder_positions = torch.broadcast_to( + torch.arange(seq_length, device=decoder_input_tokens.device), + (batch, seq_length), + ) + + position_encodings = self.position_encoding(decoder_positions) + + inputs = self.continuous_inputs_projection(decoder_input_tokens) + inputs += position_encodings + y = self.dropout(inputs) + + # decoder: No padding present. + decoder_mask = torch.ones( + decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype + ) + + # Translate encoding masks to encoder-decoder masks. + encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] + + # cross attend style: concat encodings + encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) + encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) + + for lyr in self.decoders: + y = lyr( + y, + conditioning_emb=conditioning_emb, + encoder_hidden_states=encoded, + encoder_attention_mask=encoder_decoder_mask, + )[0] + + y = self.decoder_norm(y) + y = self.post_dropout(y) + + spec_out = self.spec_out(y) + return spec_out + + +class DecoderLayer(nn.Module): + def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6): + super().__init__() + self.layer = nn.ModuleList() + + # cond self attention: layer 0 + self.layer.append( + T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) + ) + + # cross attention: layer 1 + self.layer.append( + T5LayerCrossAttention( + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + dropout_rate=dropout_rate, + layer_norm_epsilon=layer_norm_epsilon, + ) + ) + + # Film Cond MLP + dropout: last layer + self.layer.append( + T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) + ) + + def forward( + self, + hidden_states, + conditioning_emb=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + ): + hidden_states = self.layer[0]( + hidden_states, + conditioning_emb=conditioning_emb, + attention_mask=attention_mask, + ) + + if encoder_hidden_states is not None: + encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( + encoder_hidden_states.dtype + ) + + hidden_states = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_extended_attention_mask, + ) + + # Apply Film Conditional Feed Forward layer + hidden_states = self.layer[-1](hidden_states, conditioning_emb) + + return (hidden_states,) + + +class T5LayerSelfAttentionCond(nn.Module): + def __init__(self, d_model, d_kv, num_heads, dropout_rate): + super().__init__() + self.layer_norm = T5LayerNorm(d_model) + self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states, + conditioning_emb=None, + attention_mask=None, + ): + # pre_self_attention_layer_norm + normed_hidden_states = self.layer_norm(hidden_states) + + if conditioning_emb is not None: + normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) + + # Self-attention block + attention_output = self.attention(normed_hidden_states) + + hidden_states = hidden_states + self.dropout(attention_output) + + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon): + super().__init__() + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states, + key_value_states=None, + attention_mask=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + encoder_hidden_states=key_value_states, + attention_mask=attention_mask.squeeze(1), + ) + layer_output = hidden_states + self.dropout(attention_output) + return layer_output + + +class T5LayerFFCond(nn.Module): + def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) + self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, hidden_states, conditioning_emb=None): + forwarded_states = self.layer_norm(hidden_states) + if conditioning_emb is not None: + forwarded_states = self.film(forwarded_states, conditioning_emb) + + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, d_model, d_ff, dropout_rate): + super().__init__() + self.wi_0 = nn.Linear(d_model, d_ff, bias=False) + self.wi_1 = nn.Linear(d_model, d_ff, bias=False) + self.wo = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout_rate) + self.act = NewGELUActivation() + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class T5FiLMLayer(nn.Module): + """ + FiLM Layer + """ + + def __init__(self, in_features, out_features): + super().__init__() + self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) + + def forward(self, x, conditioning_emb): + emb = self.scale_bias(conditioning_emb) + scale, shift = torch.chunk(emb, 2, -1) + x = x * (1 + scale) + shift + return x diff --git a/diffusers/models/transformer_2d.py b/diffusers/models/transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..83da16838ae2248c31faada9cd5704d20500459c --- /dev/null +++ b/diffusers/models/transformer_2d.py @@ -0,0 +1,341 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..models.embeddings import ImagePositionalEmbeddings +from ..utils import BaseOutput, deprecate +from .attention import BasicTransformerBlock +from .embeddings import PatchEmbed +from .modeling_utils import ModelMixin + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffusers/models/transformer_temporal.py b/diffusers/models/transformer_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..cfafdb055bcfedc911b0a19d1e5da8089a18b215 --- /dev/null +++ b/diffusers/models/transformer_temporal.py @@ -0,0 +1,179 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .attention import BasicTransformerBlock +from .modeling_utils import ModelMixin + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + The output of [`TransformerTemporalModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, channel, num_frames) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/diffusers/models/unet_1d.py b/diffusers/models/unet_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..9b617388f3917c97e8aef39ec0f386eb2e4a1254 --- /dev/null +++ b/diffusers/models/unet_1d.py @@ -0,0 +1,255 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block + + +@dataclass +class UNet1DOutput(BaseOutput): + """ + The output of [`UNet1DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class UNet1DModel(ModelMixin, ConfigMixin): + r""" + A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + extra_in_channels (`int`, *optional*, defaults to 0): + Number of additional channels to be added to the input of the first down block. Useful for cases where the + input data has more channels than what the model was initially designed for. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): + Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. + act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. + layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. + downsample_each_block (`int`, *optional*, defaults to `False`): + Experimental feature for using a UNet without upsampling. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 65536, + sample_rate: Optional[int] = None, + in_channels: int = 2, + out_channels: int = 2, + extra_in_channels: int = 0, + time_embedding_type: str = "fourier", + flip_sin_to_cos: bool = True, + use_timestep_embedding: bool = False, + freq_shift: float = 0.0, + down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), + up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: Tuple[str] = "UNetMidBlock1D", + out_block_type: str = None, + block_out_channels: Tuple[int] = (32, 32, 64), + act_fn: str = None, + norm_num_groups: int = 8, + layers_per_block: int = 1, + downsample_each_block: bool = False, + ): + super().__init__() + self.sample_size = sample_size + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift + ) + timestep_input_dim = block_out_channels[0] + + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.out_block = None + + # down + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + if i == 0: + input_channel += extra_in_channels + + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = ( + reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels + ) + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block, + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[UNet1DOutput, Tuple]: + r""" + The [`UNet1DModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_1d.UNet1DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) + timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) + + # 2. down + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) + down_block_res_samples += res_samples + + # 3. mid + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) + + # 4. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) + + # 5. post-process + if self.out_block: + sample = self.out_block(sample, timestep_embed) + + if not return_dict: + return (sample,) + + return UNet1DOutput(sample=sample) diff --git a/diffusers/models/unet_1d_blocks.py b/diffusers/models/unet_1d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..84ae48e0f8c4f3da6132a02c3e89f7c976a2b150 --- /dev/null +++ b/diffusers/models/unet_1d_blocks.py @@ -0,0 +1,656 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from .activations import get_activation +from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + conv_shortcut=False, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states, temb=None): + output_states = () + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + + self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x, temb=None): + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dim, + num_layers: int = 1, + add_downsample: bool = False, + add_upsample: bool = False, + non_linearity=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.upsample = None + if add_upsample: + self.upsample = Downsample1D(out_channels, use_conv=True) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states, temb): + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + self.downsample = self.downsample(hidden_states) + + return hidden_states + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + self.final_conv1d_act = get_activation(act_fn) + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states, temb=None): + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim, embed_dim, act_fn="mish"): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + get_activation(act_fn), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states, temb): + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + + return hidden_states + + +_kernels = { + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states): + hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel + return F.conv1d(hidden_states, weight, stride=2) + + +class Upsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states, temb=None): + hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel + return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) + + +class SelfAttention1d(nn.Module): + def __init__(self, in_channels, n_head=1, dropout_rate=0.0): + super().__init__() + self.channels = in_channels + self.group_norm = nn.GroupNorm(1, num_channels=in_channels) + self.num_heads = n_head + + self.query = nn.Linear(self.channels, self.channels) + self.key = nn.Linear(self.channels, self.channels) + self.value = nn.Linear(self.channels, self.channels) + + self.proj_attn = nn.Linear(self.channels, self.channels, bias=True) + + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel_dim, seq = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores, dim=-1) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.dropout(hidden_states) + + output = hidden_states + residual + + return output + + +class ResConvBlock(nn.Module): + def __init__(self, in_channels, mid_channels, out_channels, is_last=False): + super().__init__() + self.is_last = is_last + self.has_conv_skip = in_channels != out_channels + + if self.has_conv_skip: + self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) + + self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) + self.group_norm_1 = nn.GroupNorm(1, mid_channels) + self.gelu_1 = nn.GELU() + self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) + + if not self.is_last: + self.group_norm_2 = nn.GroupNorm(1, out_channels) + self.gelu_2 = nn.GELU() + + def forward(self, hidden_states): + residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states + + hidden_states = self.conv_1(hidden_states) + hidden_states = self.group_norm_1(hidden_states) + hidden_states = self.gelu_1(hidden_states) + hidden_states = self.conv_2(hidden_states) + + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + + output = hidden_states + residual + return output + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels, in_channels, out_channels=None): + super().__init__() + + out_channels = in_channels if out_channels is None else out_channels + + # there is always at least one resnet + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class AttnDownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = torch.cat([hidden_states, temb], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1DNoSkip(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states + + +def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample): + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample): + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) + return None diff --git a/diffusers/models/unet_2d.py b/diffusers/models/unet_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3b17acd3d829519465ec0d8daa41b16184aa70f2 --- /dev/null +++ b/diffusers/models/unet_2d.py @@ -0,0 +1,329 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class UNet2DOutput(BaseOutput): + """ + The output of [`UNet2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class UNet2DModel(ModelMixin, ConfigMixin): + r""" + A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - + 1)`. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): + Tuple of downsample block types. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): + Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + downsample_type (`str`, *optional*, defaults to `conv`): + The downsample type for downsampling layers. Choose between "conv" and "resnet" + upsample_type (`str`, *optional*, defaults to `conv`): + The upsample type for upsampling layers. Choose between "conv" and "resnet" + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class + conditioning with `class_embed_type` equal to `None`. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels: Tuple[int] = (224, 448, 672, 896), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", + act_fn: str = "silu", + attention_head_dim: Optional[int] = 8, + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + resnet_time_scale_shift: str = "default", + add_attention: bool = True, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DOutput, Tuple]: + r""" + The [`UNet2DModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d.UNet2DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when doing class conditioning") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb) + + # 5. up + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + if not return_dict: + return (sample,) + + return UNet2DOutput(sample=sample) diff --git a/diffusers/models/unet_2d_blocks.py b/diffusers/models/unet_2d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..469e501b814b4673ab7f18378aecb348cebbfcdf --- /dev/null +++ b/diffusers/models/unet_2d_blocks.py @@ -0,0 +1,3182 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import is_torch_version, logging +from .attention import AdaGroupNorm +from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from .dual_transformer_2d import DualTransformer2DModel +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from .transformer_2d import Transformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + downsample_padding=1, + downsample_type="conv", + ): + super().__init__() + resnets = [] + attentions = [] + self.downsample_type = downsample_type + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample_type == "conv": + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, upsample_size=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb) + else: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + skip_time_act=False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_downsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + output_states = () + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + mask, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample=False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample=True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + group_size=resnet_group_size, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + upsample_type="conv", + ): + super().__init__() + resnets = [] + attentions = [] + + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample_type == "conv": + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + temb_channels=None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, temb=None): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + add_upsample=True, + temb_channels=None, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + skip_time_act=False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + mask, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: Optional[int] = 32, + add_upsample=True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attention_head_dim=1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attention_head_dim + if (i == num_layers - 1) + else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: Optional[str] = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) + + def _to_4d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. + emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/diffusers/models/unet_2d_blocks_flax.py b/diffusers/models/unet_2d_blocks_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1447570dda34b814bdc1660dfd37874fed0125 --- /dev/null +++ b/diffusers/models/unet_2d_blocks_flax.py @@ -0,0 +1,377 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flax.linen as nn +import jax.numpy as jnp + +from .attention_flax import FlaxTransformer2DModel +from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D + + +class FlaxCrossAttnDownBlock2D(nn.Module): + r""" + Cross Attention 2D Downsizing block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + add_downsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False + use_memory_efficient_attention: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxTransformer2DModel( + in_channels=self.out_channels, + n_heads=self.num_attention_heads, + d_head=self.out_channels // self.num_attention_heads, + depth=1, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, + use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxDownBlock2D(nn.Module): + r""" + Flax 2D downsizing block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, deterministic=True): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxCrossAttnUpBlock2D(nn.Module): + r""" + Cross Attention 2D Upsampling block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + add_upsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False + use_memory_efficient_attention: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxTransformer2DModel( + in_channels=self.out_channels, + n_heads=self.num_attention_heads, + d_head=self.out_channels // self.num_attention_heads, + depth=1, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, + use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUpBlock2D(nn.Module): + r""" + Flax 2D upsampling block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + prev_output_channel (:obj:`int`): + Output channels from the previous block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2DCrossAttn(nn.Module): + r""" + Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + use_linear_projection: bool = False + use_memory_efficient_attention: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxTransformer2DModel( + in_channels=self.in_channels, + n_heads=self.num_attention_heads, + d_head=self.in_channels // self.num_attention_heads, + depth=1, + use_linear_projection=self.use_linear_projection, + use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + return hidden_states diff --git a/diffusers/models/unet_2d_condition.py b/diffusers/models/unet_2d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..c76552fb3cd3445ebc93b54679380bbc1cc8b09a --- /dev/null +++ b/diffusers/models/unet_2d_condition.py @@ -0,0 +1,1101 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import UNet2DConditionLoadersMixin +from ..utils import BaseOutput, logging +from .activations import get_activation +from .attention_processor import AttentionProcessor, AttnProcessor +from .embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name +from safetensors import safe_open +import safetensors + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = 256, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + size_cond=False, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + # self.size_cond = size_cond + # if self.size_cond: + # size_embed_dim = block_out_channels[0] * 2 + # if size_embed_dim % 2 != 0: + # raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {size_embed_dim}.") + # self.size_h_proj = GaussianFourierProjection( + # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + # ) + # self.size_w_proj = GaussianFourierProjection( + # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + # ) + # size_input_dim = size_embed_dim + # self.H_embedding = TimestepEmbedding( + # size_input_dim, + # size_embed_dim, + # act_fn=act_fn, + # post_act_fn=timestep_post_act, + # cond_proj_dim=time_cond_proj_dim, + # ) + # self.W_embedding = TimestepEmbedding( + # size_input_dim, + # size_embed_dim, + # act_fn=act_fn, + # post_act_fn=timestep_post_act, + # cond_proj_dim=time_cond_proj_dim, + # ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(addition_time_embed_dim * 8, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, weight=None, subfolder=None, size_cond=False, fp16=False): + import os + import json + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path: + from diffusers.utils import SAFETENSORS_WEIGHTS_NAME # diffusion_pytorch_model.safetensors + model_postfix = SAFETENSORS_WEIGHTS_NAME + if fp16: + model_postfix = "diffusion_pytorch_model.fp16.safetensors" + else: + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + model_postfix = WEIGHTS_NAME + + config["size_cond"] = size_cond + if size_cond: + # config.addition_embed_type = "time" + config["addition_embed_type"] = "time" + + model = cls.from_config(config) + + if weight is None: + print("Loading from pretrained SD") + model_file = os.path.join(pretrained_model_path, model_postfix) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path: + # state_dict = {} + # with safe_open(model_file, framework="pt", device=0) as f: + # for k in f.keys(): + # state_dict[k] = f.get_tensor(k) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + else: + state_dict = torch.load(model_file, map_location="cpu") + + for k, v in model.state_dict().items(): + # print(k) + if k not in state_dict.keys(): + state_dict.update({k: v}) + model.load_state_dict(state_dict) + + # conv_in_weights = model.conv_in.weight.clone() + # model.conv_in = nn.Conv2d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses + # # model.conv_in = InflatedConv3d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=3, padding=(1, 1)) + # # nn.Conv2d(4 + 4*n_poses, weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses + # with torch.no_grad(): + # model.conv_in.weight[:, :4] = conv_in_weights # original weights + # model.conv_in.weight[:, 4:] = torch.zeros(model.conv_in.weight[:, 4:].shape) # new weights initialized to zero + + return model + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + # H=None, + # W=None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + # if self.size_cond: + # h_emb = self.H_embedding(self.size_h_proj(H)) + # w_emb = self.W_embedding(self.size_w_proj(W)) + # emb = emb + torch.cat((h_emb, w_emb), dim=1) + + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "time": + time_ids = added_cond_kwargs.get("time_ids") + batch_size = time_ids.shape[0] + # time_embeds = self.add_time_proj(time_ids) + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + + add_embeds = time_embeds + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/diffusers/models/unet_2d_condition_flax.py b/diffusers/models/unet_2d_condition_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..de39bc75d2e392a423c9ea09e979b9f42d818dc1 --- /dev/null +++ b/diffusers/models/unet_2d_condition_flax.py @@ -0,0 +1,357 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..utils import BaseOutput +from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .modeling_flax_utils import FlaxModelMixin +from .unet_2d_blocks_flax import ( + FlaxCrossAttnDownBlock2D, + FlaxCrossAttnUpBlock2D, + FlaxDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, + FlaxUpBlock2D, +) + + +@flax.struct.dataclass +class FlaxUNet2DConditionOutput(BaseOutput): + """ + The output of [`FlaxUNet2DConditionModel`]. + + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: jnp.ndarray + + +@flax_register_to_config +class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its + general usage and behavior. + + Inherent JAX features such as the following are supported: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): + The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): + The dimension of the attention heads. + num_attention_heads (`int` or `Tuple[int]`, *optional*): + The number of attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682). + """ + + sample_size: int = 32 + in_channels: int = 4 + out_channels: int = 4 + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + only_cross_attention: Union[bool, Tuple[bool]] = False + block_out_channels: Tuple[int] = (320, 640, 1280, 1280) + layers_per_block: int = 2 + attention_head_dim: Union[int, Tuple[int]] = 8 + num_attention_heads: Optional[Union[int, Tuple[int]]] = None + cross_attention_dim: int = 1280 + dropout: float = 0.0 + use_linear_projection: bool = False + dtype: jnp.dtype = jnp.float32 + flip_sin_to_cos: bool = True + freq_shift: int = 0 + use_memory_efficient_attention: bool = False + + def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] + + def setup(self): + block_out_channels = self.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + if self.num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = self.num_attention_heads or self.attention_head_dim + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps( + block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift + ) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + only_cross_attention = self.only_cross_attention + if isinstance(only_cross_attention, bool): + only_cross_attention = (only_cross_attention,) * len(self.down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(self.down_block_types) + + # down + down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + num_attention_heads=num_attention_heads[i], + add_downsample=not is_final_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # mid + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + dropout=self.dropout, + num_attention_heads=num_attention_heads[-1], + use_linear_projection=self.use_linear_projection, + use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, + ) + + # up + up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(self.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "CrossAttnUpBlock2D": + up_block = FlaxCrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + num_attention_heads=reversed_num_attention_heads[i], + add_upsample=not is_final_block, + dropout=self.dropout, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, + ) + else: + up_block = FlaxUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + add_upsample=not is_final_block, + dropout=self.dropout, + dtype=self.dtype, + ) + + up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = up_blocks + + # out + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv_out = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__( + self, + sample, + timesteps, + encoder_hidden_states, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + return_dict: bool = True, + train: bool = False, + ) -> Union[FlaxUNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + plain tuple. + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + # 1. time + if not isinstance(timesteps, jnp.ndarray): + timesteps = jnp.array([timesteps], dtype=jnp.int32) + elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: + timesteps = timesteps.astype(dtype=jnp.float32) + timesteps = jnp.expand_dims(timesteps, 0) + + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = jnp.transpose(sample, (0, 2, 3, 1)) + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + else: + sample, res_samples = down_block(sample, t_emb, deterministic=not train) + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample += down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] + if isinstance(up_block, FlaxCrossAttnUpBlock2D): + sample = up_block( + sample, + temb=t_emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + deterministic=not train, + ) + else: + sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = nn.silu(sample) + sample = self.conv_out(sample) + sample = jnp.transpose(sample, (0, 3, 1, 2)) + + if not return_dict: + return (sample,) + + return FlaxUNet2DConditionOutput(sample=sample) diff --git a/diffusers/models/unet_2d_condition_multi_branch.py b/diffusers/models/unet_2d_condition_multi_branch.py new file mode 100644 index 0000000000000000000000000000000000000000..c4784f0cc1d15181257d9ff2e454f69a2d5cc8a5 --- /dev/null +++ b/diffusers/models/unet_2d_condition_multi_branch.py @@ -0,0 +1,1205 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import UNet2DConditionLoadersMixin +from ..utils import BaseOutput, logging +from .activations import get_activation +from .attention_processor import AttentionProcessor, AttnProcessor +from .embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name +from safetensors import safe_open +import safetensors +import copy + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = 256, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + size_cond=False, + cond_num=5, + copy_last_n_block=1, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + # self.size_cond = size_cond + # if self.size_cond: + # size_embed_dim = block_out_channels[0] * 2 + # if size_embed_dim % 2 != 0: + # raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {size_embed_dim}.") + # self.size_h_proj = GaussianFourierProjection( + # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + # ) + # self.size_w_proj = GaussianFourierProjection( + # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + # ) + # size_input_dim = size_embed_dim + # self.H_embedding = TimestepEmbedding( + # size_input_dim, + # size_embed_dim, + # act_fn=act_fn, + # post_act_fn=timestep_post_act, + # cond_proj_dim=time_cond_proj_dim, + # ) + # self.W_embedding = TimestepEmbedding( + # size_input_dim, + # size_embed_dim, + # act_fn=act_fn, + # post_act_fn=timestep_post_act, + # cond_proj_dim=time_cond_proj_dim, + # ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(addition_time_embed_dim * 6, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + self.cond_num = cond_num + self.copy_last_n_block = copy_last_n_block + self.norm_num_groups = norm_num_groups + self.up_blocks_branch = nn.ModuleList([]) + for i in range(self.cond_num): + copy_block_list = nn.ModuleList([]) + for j in range(self.copy_last_n_block, 0, -1): + copy_block_list.append(copy.deepcopy(self.up_blocks[-j])) + self.up_blocks_branch.append(copy_block_list) + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + self.conv_norm_out_branch = nn.ModuleList([]) + self.conv_act_branch = nn.ModuleList([]) + for i in range(self.cond_num): + self.conv_norm_out_branch.append(copy.deepcopy(self.conv_norm_out)) + self.conv_act_branch.append(copy.deepcopy(self.conv_act)) + + else: + self.conv_norm_out = None + self.conv_act = None + self.conv_norm_out_branch = [None] * self.cond_num + self.conv_act_branch = [None] * self.cond_num + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + self.conv_out_branch = nn.ModuleList([]) + for i in range(self.cond_num): + self.conv_out_branch.append(copy.deepcopy(self.conv_out)) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, weight=None, subfolder=None, size_cond=False, fp16=False, cond_num=5, copy_last_n_block=1): + import os + import json + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path: + from diffusers.utils import SAFETENSORS_WEIGHTS_NAME # diffusion_pytorch_model.safetensors + model_postfix = SAFETENSORS_WEIGHTS_NAME + if fp16: + model_postfix = "diffusion_pytorch_model.fp16.safetensors" + else: + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + model_postfix = WEIGHTS_NAME + + config["size_cond"] = size_cond + if size_cond: + # config.addition_embed_type = "time" + config["addition_embed_type"] = "time" + + config["cond_num"] = cond_num + config["copy_last_n_block"] = copy_last_n_block + + model = cls.from_config(config) + + if weight is None: + print("Loading from pretrained SD") + model_file = os.path.join(pretrained_model_path, model_postfix) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path: + # state_dict = {} + # with safe_open(model_file, framework="pt", device=0) as f: + # for k in f.keys(): + # state_dict[k] = f.get_tensor(k) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + else: + state_dict = torch.load(model_file, map_location="cpu") + + for k, v in model.state_dict().items(): + # print(k) + if k not in state_dict.keys(): + state_dict.update({k: v}) + model.load_state_dict(state_dict) + + for i in range(model.cond_num): + for j in range(model.copy_last_n_block, 0, -1): + model.up_blocks_branch[i][-j] = copy.deepcopy(model.up_blocks[-j]) + if model.norm_num_groups is not None: + for i in range(model.cond_num): + model.conv_norm_out_branch[i] = copy.deepcopy(model.conv_norm_out) + model.conv_act_branch[i] = copy.deepcopy(model.conv_act) + for i in range(model.cond_num): + model.conv_out_branch[i] = copy.deepcopy(model.conv_out) + # conv_in_weights = model.conv_in.weight.clone() + # model.conv_in = nn.Conv2d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses + # # model.conv_in = InflatedConv3d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=3, padding=(1, 1)) + # # nn.Conv2d(4 + 4*n_poses, weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses + # with torch.no_grad(): + # model.conv_in.weight[:, :4] = conv_in_weights # original weights + # model.conv_in.weight[:, 4:] = torch.zeros(model.conv_in.weight[:, 4:].shape) # new weights initialized to zero + + return model + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + # H=None, + # W=None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + # if self.size_cond: + # h_emb = self.H_embedding(self.size_h_proj(H)) + # w_emb = self.W_embedding(self.size_w_proj(W)) + # emb = emb + torch.cat((h_emb, w_emb), dim=1) + + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "time": + time_ids = added_cond_kwargs.get("time_ids") + batch_size = time_ids.shape[0] + # time_embeds = self.add_time_proj(time_ids) + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + + add_embeds = time_embeds + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks[:-self.copy_last_n_block]): + is_final_block = False + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + sample_list = [] + for j in range(self.cond_num): + sample_copy = sample.clone() + down_block_res_samples_copy = down_block_res_samples + tuple() + for i, upsample_block in enumerate(self.up_blocks_branch[j]): + is_final_block = i == len(self.up_blocks_branch[j]) - 1 + + res_samples = down_block_res_samples_copy[-len(upsample_block.resnets) :] + down_block_res_samples_copy = down_block_res_samples_copy[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples_copy[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample_copy = upsample_block( + hidden_states=sample_copy, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample_copy = upsample_block( + hidden_states=sample_copy, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + sample_list.append(sample_copy) + + for i, upsample_block in enumerate(self.up_blocks[-self.copy_last_n_block:]): + is_final_block = is_final_block = i == len(self.up_blocks[-self.copy_last_n_block:]) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 6. post-process + output_list = [sample] + for i, sample in enumerate(sample_list): + if self.conv_norm_out_branch[i]: + sample = self.conv_norm_out_branch[i](sample) + sample = self.conv_act_branch[i](sample) + sample = self.conv_out_branch[i](sample) + output_list.append(sample) + sample = torch.cat(output_list, dim=1) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/diffusers/models/unet_2d_condition_multi_branch_downup.py b/diffusers/models/unet_2d_condition_multi_branch_downup.py new file mode 100644 index 0000000000000000000000000000000000000000..1062419d37c7704dadf791296536e0bf75d73b91 --- /dev/null +++ b/diffusers/models/unet_2d_condition_multi_branch_downup.py @@ -0,0 +1,1319 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import UNet2DConditionLoadersMixin +from ..utils import BaseOutput, logging +from .activations import get_activation +from .attention_processor import AttentionProcessor, AttnProcessor +from .embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from .modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name +from safetensors import safe_open +import safetensors +import copy + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = 256, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + size_cond=False, + branch_num=5, + copy_last_n_block=1, + copy_first_n_block=1, + fusion: str = "sum", + off_wa=False, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + self.branch_num = branch_num + self.copy_last_n_block = copy_last_n_block + self.copy_first_n_block = copy_first_n_block + self.norm_num_groups = norm_num_groups + self.fusion = fusion + + if self.fusion == "sum": + pass + elif self.fusion == "avg": + pass + elif self.fusion == "learn": + self.fusion_conv = nn.Conv2d(block_out_channels[self.copy_first_n_block - 1] * (self.branch_num + 1), block_out_channels[self.copy_first_n_block - 1], kernel_size=3, padding=1) + else: + assert False + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + self.conv_in_branch = nn.ModuleList([]) + for i in range(self.branch_num): + self.conv_in_branch.append(copy.deepcopy(self.conv_in)) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + # self.size_cond = size_cond + # if self.size_cond: + # size_embed_dim = block_out_channels[0] * 2 + # if size_embed_dim % 2 != 0: + # raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {size_embed_dim}.") + # self.size_h_proj = GaussianFourierProjection( + # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + # ) + # self.size_w_proj = GaussianFourierProjection( + # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + # ) + # size_input_dim = size_embed_dim + # self.H_embedding = TimestepEmbedding( + # size_input_dim, + # size_embed_dim, + # act_fn=act_fn, + # post_act_fn=timestep_post_act, + # cond_proj_dim=time_cond_proj_dim, + # ) + # self.W_embedding = TimestepEmbedding( + # size_input_dim, + # size_embed_dim, + # act_fn=act_fn, + # post_act_fn=timestep_post_act, + # cond_proj_dim=time_cond_proj_dim, + # ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + if off_wa: + size_cond_num = 6 + else: + size_cond_num = 8 + self.add_embedding = TimestepEmbedding(addition_time_embed_dim * size_cond_num, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.down_blocks.append(down_block) + + self.down_blocks_branch = nn.ModuleList([]) + for i in range(self.branch_num): + copy_block_list = nn.ModuleList([]) + for j in range(self.copy_first_n_block): + copy_block_list.append(copy.deepcopy(self.down_blocks[j])) + self.down_blocks_branch.append(copy_block_list) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + self.up_blocks_branch = nn.ModuleList([]) + for i in range(self.branch_num): + copy_block_list = nn.ModuleList([]) + for j in range(self.copy_last_n_block, 0, -1): + copy_block_list.append(copy.deepcopy(self.up_blocks[-j])) + self.up_blocks_branch.append(copy_block_list) + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + self.conv_norm_out_branch = nn.ModuleList([]) + self.conv_act_branch = nn.ModuleList([]) + for i in range(self.branch_num): + self.conv_norm_out_branch.append(copy.deepcopy(self.conv_norm_out)) + self.conv_act_branch.append(copy.deepcopy(self.conv_act)) + + else: + self.conv_norm_out = None + self.conv_act = None + self.conv_norm_out_branch = [None] * self.branch_num + self.conv_act_branch = [None] * self.branch_num + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + self.conv_out_branch = nn.ModuleList([]) + for i in range(self.branch_num): + self.conv_out_branch.append(copy.deepcopy(self.conv_out)) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, weight=None, subfolder=None, size_cond=False, fp16=False, branch_num=5, copy_first_n_block=1, copy_last_n_block=1, fusion="sum", off_wa=False): + import os + import json + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path: + from diffusers.utils import SAFETENSORS_WEIGHTS_NAME # diffusion_pytorch_model.safetensors + model_postfix = SAFETENSORS_WEIGHTS_NAME + if fp16: + model_postfix = "diffusion_pytorch_model.fp16.safetensors" + else: + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + model_postfix = WEIGHTS_NAME + + config["size_cond"] = size_cond + if size_cond: + # config.addition_embed_type = "time" + config["addition_embed_type"] = "time" + + config["branch_num"] = branch_num + config["copy_first_n_block"] = copy_first_n_block + config["copy_last_n_block"] = copy_last_n_block + config["fusion"] = fusion + config["off_wa"] = off_wa + + model = cls.from_config(config) + + if weight is None: + print("Loading from pretrained SD") + model_file = os.path.join(pretrained_model_path, model_postfix) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path: + # state_dict = {} + # with safe_open(model_file, framework="pt", device=0) as f: + # for k in f.keys(): + # state_dict[k] = f.get_tensor(k) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + else: + state_dict = torch.load(model_file, map_location="cpu") + + for k, v in model.state_dict().items(): + # print(k) + if k not in state_dict.keys(): + state_dict.update({k: v}) + model.load_state_dict(state_dict) + + for i in range(model.branch_num): + for j in range(model.copy_first_n_block): + model.down_blocks_branch[i][j] = copy.deepcopy(model.down_blocks[j]) + for i in range(model.branch_num): + for j in range(model.copy_last_n_block, 0, -1): + model.up_blocks_branch[i][-j] = copy.deepcopy(model.up_blocks[-j]) + if model.norm_num_groups is not None: + for i in range(model.branch_num): + model.conv_norm_out_branch[i] = copy.deepcopy(model.conv_norm_out) + model.conv_act_branch[i] = copy.deepcopy(model.conv_act) + for i in range(model.branch_num): + model.conv_out_branch[i] = copy.deepcopy(model.conv_out) + # conv_in_weights = model.conv_in.weight.clone() + # model.conv_in = nn.Conv2d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses + # # model.conv_in = InflatedConv3d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=3, padding=(1, 1)) + # # nn.Conv2d(4 + 4*n_poses, weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses + # with torch.no_grad(): + # model.conv_in.weight[:, :4] = conv_in_weights # original weights + # model.conv_in.weight[:, 4:] = torch.zeros(model.conv_in.weight[:, 4:].shape) # new weights initialized to zero + + return model + + def forward( + self, + sample: torch.FloatTensor, + sample_noisy_list: List[torch.FloatTensor], + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + # H=None, + # W=None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + # if self.size_cond: + # h_emb = self.H_embedding(self.size_h_proj(H)) + # w_emb = self.W_embedding(self.size_w_proj(W)) + # emb = emb + torch.cat((h_emb, w_emb), dim=1) + + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "time": + time_ids = added_cond_kwargs.get("time_ids") + batch_size = time_ids.shape[0] + # time_embeds = self.add_time_proj(time_ids) + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + + add_embeds = time_embeds + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + + # 2. pre-process + sample = self.conv_in(sample) + + sample_in_list = [] + for i in range(self.branch_num): + sample_in_list.append(self.conv_in_branch[i](sample_noisy_list[i])) + + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks[:self.copy_first_n_block]: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + down_block_res_samples_list = [] + for i in range(self.branch_num): + down_block_res_samples_list.append((sample_in_list[i],)) + + for j in range(self.branch_num): + for i, downsample_block in enumerate(self.down_blocks_branch[j]): + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample_in_list[j], res_samples = downsample_block( + hidden_states=sample_in_list[j], + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample_in_list[j], res_samples = downsample_block(hidden_states=sample_in_list[j], temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample_in_list[j] += down_block_additional_residuals.pop(0) + + down_block_res_samples_list[j] += res_samples + + # sample = (sample + sample_normal) / 2. + sample_list = [sample] + for i in range(self.branch_num): + sample_list.append(sample_in_list[i]) + + if self.fusion == "sum" or self.fusion == "avg": + stacked_tensor = torch.stack(sample_list, dim=0) + sample = torch.sum(stacked_tensor, dim=0) + if self.fusion == "avg": + sample = sample / (1 + self.branch_num) + elif self.fusion == "learn": + concat_tensor = torch.cat(sample_list, dim=1) + sample = self.fusion_conv(concat_tensor) + else: + assert False + + for downsample_block in self.down_blocks[self.copy_first_n_block:]: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + for i in range(self.branch_num): + down_block_res_samples_list[i] += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks[:-self.copy_last_n_block]): + is_final_block = False + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + for j in range(self.branch_num): + down_block_res_samples_list[j] = down_block_res_samples_list[j][: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + sample_list = [] + for j in range(self.branch_num): + sample_copy = sample.clone() + for i, upsample_block in enumerate(self.up_blocks_branch[j]): + is_final_block = i == len(self.up_blocks_branch[j]) - 1 + + res_samples = down_block_res_samples_list[j][-len(upsample_block.resnets) :] + down_block_res_samples_list[j] = down_block_res_samples_list[j][: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples_list[j][-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample_copy = upsample_block( + hidden_states=sample_copy, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample_copy = upsample_block( + hidden_states=sample_copy, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + sample_list.append(sample_copy) + + for i, upsample_block in enumerate(self.up_blocks[-self.copy_last_n_block:]): + is_final_block = i == len(self.up_blocks[-self.copy_last_n_block:]) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 6. post-process + output_list = [sample] + for i, sample_tmp in enumerate(sample_list): + if self.conv_norm_out_branch[i]: + sample_tmp = self.conv_norm_out_branch[i](sample_tmp) + sample_tmp = self.conv_act_branch[i](sample_tmp) + sample_tmp = self.conv_out_branch[i](sample_tmp) + output_list.append(sample_tmp) + sample = torch.cat(output_list, dim=1) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/diffusers/models/unet_3d_blocks.py b/diffusers/models/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5c393518e2ad8edf21069dfcd417392001569d --- /dev/null +++ b/diffusers/models/unet_3d_blocks.py @@ -0,0 +1,679 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +from .transformer_2d import Transformer2DModel +from .transformer_temporal import TransformerTemporalModel + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + )[0] + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/diffusers/models/unet_3d_condition.py b/diffusers/models/unet_3d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2a8f1179ef9654b5234d63528468e59e371b10 --- /dev/null +++ b/diffusers/models/unet_3d_condition.py @@ -0,0 +1,627 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import UNet2DConditionLoadersMixin +from ..utils import BaseOutput, logging +from .attention_processor import AttentionProcessor, AttnProcessor +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + The output of [`UNet3DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): The number of attention heads. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise NotImplementedError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def enable_forward_chunking(self, chunk_size=None, dim=0): + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + The [`UNet3DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + + Returns: + [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/diffusers/models/vae.py b/diffusers/models/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..edd516dd380aa6f5888174bbd5f3df86be187feb --- /dev/null +++ b/diffusers/models/vae.py @@ -0,0 +1,441 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn + +from ..utils import BaseOutput, is_torch_version, randn_tensor +from .attention_processor import SpatialNorm +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class DecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = torch.nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x): + sample = x + sample = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + norm_type="group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, z, latent_embeds=None): + sample = z + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype + ) + x = self.mean + self.std * sample + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean diff --git a/diffusers/models/vae_flax.py b/diffusers/models/vae_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f5b1d0e399ab8e58d81d396d19b6f082192f5a --- /dev/null +++ b/diffusers/models/vae_flax.py @@ -0,0 +1,869 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers + +import math +from functools import partial +from typing import Tuple + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..utils import BaseOutput +from .modeling_flax_utils import FlaxModelMixin + + +@flax.struct.dataclass +class FlaxDecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The `dtype` of the parameters. + """ + + sample: jnp.ndarray + + +@flax.struct.dataclass +class FlaxAutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`FlaxDiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`. + `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "FlaxDiagonalGaussianDistribution" + + +class FlaxUpsample2D(nn.Module): + """ + Flax implementation of 2D Upsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.in_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + """ + Flax implementation of 2D Downsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.in_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + """ + Flax implementation of 2D Resnet Block. + + Args: + in_channels (`int`): + Input channels + out_channels (`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for group norm. + use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): + Whether to use `nin_shortcut`. This activates a new layer inside ResNet block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int = None + dropout: float = 0.0 + groups: int = 32 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) + self.dropout_layer = nn.Dropout(self.dropout) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual + + +class FlaxAttentionBlock(nn.Module): + r""" + Flax Convolutional based multi-head attention block for diffusion-based VAE. + + Parameters: + channels (:obj:`int`): + Input channels + num_head_channels (:obj:`int`, *optional*, defaults to `None`): + Number of attention heads + num_groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for group norm + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + + """ + channels: int + num_head_channels: int = None + num_groups: int = 32 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1 + + dense = partial(nn.Dense, self.channels, dtype=self.dtype) + + self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6) + self.query, self.key, self.value = dense(), dense(), dense() + self.proj_attn = dense() + + def transpose_for_scores(self, projection): + new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) + new_projection = projection.reshape(new_projection_shape) + # (B, T, H, D) -> (B, H, T, D) + new_projection = jnp.transpose(new_projection, (0, 2, 1, 3)) + return new_projection + + def __call__(self, hidden_states): + residual = hidden_states + batch, height, width, channels = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.reshape((batch, height * width, channels)) + + query = self.query(hidden_states) + key = self.key(hidden_states) + value = self.value(hidden_states) + + # transpose + query = self.transpose_for_scores(query) + key = self.transpose_for_scores(key) + value = self.transpose_for_scores(value) + + # compute attentions + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale) + attn_weights = nn.softmax(attn_weights, axis=-1) + + # attend to values + hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) + + hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3)) + new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,) + hidden_states = hidden_states.reshape(new_hidden_states_shape) + + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.reshape((batch, height, width, channels)) + hidden_states = hidden_states + residual + return hidden_states + + +class FlaxDownEncoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Encoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + resnet_groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for the Resnet block group norm + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + resnet_groups: int = 32 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout=self.dropout, + groups=self.resnet_groups, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, deterministic=deterministic) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUpDecoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Decoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + resnet_groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for the Resnet block group norm + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + resnet_groups: int = 32 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout=self.dropout, + groups=self.resnet_groups, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2D(nn.Module): + r""" + Flax Unet Mid-Block module. + + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + resnet_groups (:obj:`int`, *optional*, defaults to `32`): + The number of groups to use for the Resnet and Attention block group norm + num_attention_heads (:obj:`int`, *optional*, defaults to `1`): + Number of attention heads for each attention block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + resnet_groups: int = 32 + num_attention_heads: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout=self.dropout, + groups=resnet_groups, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxAttentionBlock( + channels=self.in_channels, + num_head_channels=self.num_attention_heads, + num_groups=resnet_groups, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout=self.dropout, + groups=resnet_groups, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, deterministic=deterministic) + + return hidden_states + + +class FlaxEncoder(nn.Module): + r""" + Flax Implementation of VAE Encoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + DownEncoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int = 3 + out_channels: int = 3 + down_block_types: Tuple[str] = ("DownEncoderBlock2D",) + block_out_channels: Tuple[int] = (64,) + layers_per_block: int = 2 + norm_num_groups: int = 32 + act_fn: str = "silu" + double_z: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + block_out_channels = self.block_out_channels + # in + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # downsampling + down_blocks = [] + output_channel = block_out_channels[0] + for i, _ in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = FlaxDownEncoderBlock2D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + resnet_groups=self.norm_num_groups, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # middle + self.mid_block = FlaxUNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_groups=self.norm_num_groups, + num_attention_heads=None, + dtype=self.dtype, + ) + + # end + conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels + self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) + self.conv_out = nn.Conv( + conv_out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, sample, deterministic: bool = True): + # in + sample = self.conv_in(sample) + + # downsampling + for block in self.down_blocks: + sample = block(sample, deterministic=deterministic) + + # middle + sample = self.mid_block(sample, deterministic=deterministic) + + # end + sample = self.conv_norm_out(sample) + sample = nn.swish(sample) + sample = self.conv_out(sample) + + return sample + + +class FlaxDecoder(nn.Module): + r""" + Flax Implementation of VAE Decoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + UpDecoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ + in_channels: int = 3 + out_channels: int = 3 + up_block_types: Tuple[str] = ("UpDecoderBlock2D",) + block_out_channels: int = (64,) + layers_per_block: int = 2 + norm_num_groups: int = 32 + act_fn: str = "silu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + block_out_channels = self.block_out_channels + + # z to block_in + self.conv_in = nn.Conv( + block_out_channels[-1], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # middle + self.mid_block = FlaxUNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_groups=self.norm_num_groups, + num_attention_heads=None, + dtype=self.dtype, + ) + + # upsampling + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + up_blocks = [] + for i, _ in enumerate(self.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = FlaxUpDecoderBlock2D( + in_channels=prev_output_channel, + out_channels=output_channel, + num_layers=self.layers_per_block + 1, + resnet_groups=self.norm_num_groups, + add_upsample=not is_final_block, + dtype=self.dtype, + ) + up_blocks.append(up_block) + prev_output_channel = output_channel + + self.up_blocks = up_blocks + + # end + self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) + self.conv_out = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, sample, deterministic: bool = True): + # z to block_in + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample, deterministic=deterministic) + + # upsampling + for block in self.up_blocks: + sample = block(sample, deterministic=deterministic) + + sample = self.conv_norm_out(sample) + sample = nn.swish(sample) + sample = self.conv_out(sample) + + return sample + + +class FlaxDiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + # Last axis to account for channels-last + self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) + self.logvar = jnp.clip(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = jnp.exp(0.5 * self.logvar) + self.var = jnp.exp(self.logvar) + if self.deterministic: + self.var = self.std = jnp.zeros_like(self.mean) + + def sample(self, key): + return self.mean + self.std * jax.random.normal(key, self.mean.shape) + + def kl(self, other=None): + if self.deterministic: + return jnp.array([0.0]) + + if other is None: + return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) + + return 0.5 * jnp.sum( + jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, + axis=[1, 2, 3], + ) + + def nll(self, sample, axis=[1, 2, 3]): + if self.deterministic: + return jnp.array([0.0]) + + logtwopi = jnp.log(2.0 * jnp.pi) + return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) + + def mode(self): + return self.mean + + +@flax_register_to_config +class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + Flax implementation of a VAE model with KL loss for decoding latent representations. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its + general usage and behavior. + + Inherent JAX features such as the following are supported: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): + Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + Tuple of upsample block types. + block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): + Number of ResNet layer for each block. + act_fn (`str`, *optional*, defaults to `silu`): + The activation function to use. + latent_channels (`int`, *optional*, defaults to `4`): + Number of channels in the latent space. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups for normalization. + sample_size (`int`, *optional*, defaults to 32): + Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The `dtype` of the parameters. + """ + in_channels: int = 3 + out_channels: int = 3 + down_block_types: Tuple[str] = ("DownEncoderBlock2D",) + up_block_types: Tuple[str] = ("UpDecoderBlock2D",) + block_out_channels: Tuple[int] = (64,) + layers_per_block: int = 1 + act_fn: str = "silu" + latent_channels: int = 4 + norm_num_groups: int = 32 + sample_size: int = 32 + scaling_factor: float = 0.18215 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.encoder = FlaxEncoder( + in_channels=self.config.in_channels, + out_channels=self.config.latent_channels, + down_block_types=self.config.down_block_types, + block_out_channels=self.config.block_out_channels, + layers_per_block=self.config.layers_per_block, + act_fn=self.config.act_fn, + norm_num_groups=self.config.norm_num_groups, + double_z=True, + dtype=self.dtype, + ) + self.decoder = FlaxDecoder( + in_channels=self.config.latent_channels, + out_channels=self.config.out_channels, + up_block_types=self.config.up_block_types, + block_out_channels=self.config.block_out_channels, + layers_per_block=self.config.layers_per_block, + norm_num_groups=self.config.norm_num_groups, + act_fn=self.config.act_fn, + dtype=self.dtype, + ) + self.quant_conv = nn.Conv( + 2 * self.config.latent_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + self.post_quant_conv = nn.Conv( + self.config.latent_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + + params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) + rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng} + + return self.init(rngs, sample)["params"] + + def encode(self, sample, deterministic: bool = True, return_dict: bool = True): + sample = jnp.transpose(sample, (0, 2, 3, 1)) + + hidden_states = self.encoder(sample, deterministic=deterministic) + moments = self.quant_conv(hidden_states) + posterior = FlaxDiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return FlaxAutoencoderKLOutput(latent_dist=posterior) + + def decode(self, latents, deterministic: bool = True, return_dict: bool = True): + if latents.shape[-1] != self.config.latent_channels: + latents = jnp.transpose(latents, (0, 2, 3, 1)) + + hidden_states = self.post_quant_conv(latents) + hidden_states = self.decoder(hidden_states, deterministic=deterministic) + + hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) + + if not return_dict: + return (hidden_states,) + + return FlaxDecoderOutput(sample=hidden_states) + + def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): + posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) + if sample_posterior: + rng = self.make_rng("gaussian") + hidden_states = posterior.latent_dist.sample(rng) + else: + hidden_states = posterior.latent_dist.mode() + + sample = self.decode(hidden_states, return_dict=return_dict).sample + + if not return_dict: + return (sample,) + + return FlaxDecoderOutput(sample=sample) diff --git a/diffusers/models/vq_model.py b/diffusers/models/vq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..687449e8c7557473c0af994b30ef4c7dfba9718c --- /dev/null +++ b/diffusers/models/vq_model.py @@ -0,0 +1,167 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, apply_forward_hook +from .modeling_utils import ModelMixin +from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The encoded output sample from the last layer of the model. + """ + + latents: torch.FloatTensor + + +class VQModel(ModelMixin, ConfigMixin): + r""" + A VQ-VAE model for decoding latent representations. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. + scaling_factor (`float`, *optional*, defaults to `0.18215`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + norm_num_groups: int = 32, + vq_embed_dim: Optional[int] = None, + scaling_factor: float = 0.18215, + norm_type: str = "group", # group, spatial + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + ) + + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + + self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + norm_type=norm_type, + ) + + @apply_forward_hook + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.encoder(x) + h = self.quant_conv(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + @apply_forward_hook + def decode( + self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant2 = self.post_quant_conv(quant) + dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + The [`VQModel`] forward method. + + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vq_model.VQEncoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` + is returned. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/diffusers/optimization.py b/diffusers/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..46e6125a0f5565b80ced30dfc147f8168ef35a5c --- /dev/null +++ b/diffusers/optimization.py @@ -0,0 +1,354 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for diffusion models.""" + +import math +from enum import Enum +from typing import Optional, Union + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class SchedulerType(Enum): + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT = "constant" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + PIECEWISE_CONSTANT = "piecewise_constant" + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) + + +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + step_rules (`string`): + The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate + if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 + steps and multiple 0.005 for the other steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + rules_dict = {} + rule_list = step_rules.split(",") + for rule_str in rule_list[:-1]: + value_str, steps_str = rule_str.split(":") + steps = int(steps_str) + value = float(value_str) + rules_dict[steps] = value + last_lr_multiple = float(rule_list[-1]) + + def create_rules_function(rules_dict, last_lr_multiple): + def rule_func(steps: int) -> float: + sorted_steps = sorted(rules_dict.keys()) + for i, sorted_step in enumerate(sorted_steps): + if steps < sorted_step: + return rules_dict[sorted_steps[i]] + return last_lr_multiple + + return rule_func + + rules_func = create_rules_function(rules_dict, last_lr_multiple) + + return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_periods (`float`, *optional*, defaults to 0.5): + The number of periods of the cosine function in a schedule (the default is to just decrease from the max + value to 0 following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, + SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + step_rules: Optional[str] = None, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + num_cycles: int = 1, + power: float = 1.0, + last_epoch: int = -1, +): + """ + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + step_rules (`str`, *optional*): + A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_cycles (`int`, *optional*): + The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. + power (`float`, *optional*, defaults to 1.0): + Power factor. See `POLYNOMIAL` scheduler + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer, last_epoch=last_epoch) + + if name == SchedulerType.PIECEWISE_CONSTANT: + return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + last_epoch=last_epoch, + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + power=power, + last_epoch=last_epoch, + ) + + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch + ) diff --git a/diffusers/pipeline_utils.py b/diffusers/pipeline_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..87709d5f616cdfb195ed4527e4b630a86136c29c --- /dev/null +++ b/diffusers/pipeline_utils.py @@ -0,0 +1,29 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + +# NOTE: This file is deprecated and will be removed in a future version. +# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works + +from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 +from .utils import deprecate + + +deprecate( + "pipelines_utils", + "0.22.0", + "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.", + standard_warn=False, + stacklevel=3, +) diff --git a/diffusers/pipelines/README.md b/diffusers/pipelines/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7562040596e9028ed56431817f42f4379ecf3435 --- /dev/null +++ b/diffusers/pipelines/README.md @@ -0,0 +1,171 @@ +# 🧨 Diffusers Pipelines + +Pipelines provide a simple way to run state-of-the-art diffusion models in inference. +Most diffusion systems consist of multiple independently-trained models and highly adaptable scheduler +components - all of which are needed to have a functioning end-to-end diffusion system. + +As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models: +- [Autoencoder](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/vae.py#L392) +- [Conditional Unet](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/unet_2d_condition.py#L12) +- [CLIP text encoder](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel) +- a scheduler component, [scheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py), +- a [CLIPImageProcessor](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPImageProcessor), +- as well as a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py). +All of these components are necessary to run stable diffusion in inference even though they were trained +or created independently from each other. + +To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API. +More specifically, we strive to provide pipelines that +- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)), +- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section), +- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)), +- 4. can easily be contributed by the community (see the [Contribution](#contribution) section). + +**Note** that pipelines do not (and should not) offer any training functionality. +If you are looking for *official* training examples, please have a look at [examples](https://github.com/huggingface/diffusers/tree/main/examples). + + +## Pipelines Summary + +The following table summarizes all officially supported pipelines, their corresponding paper, and if +available a colab notebook to directly try them out. + +| Pipeline | Source | Tasks | Colab +|-------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:---:|:---:| +| [dance diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/Harmonai-org/sample-generator) | *Unconditional Audio Generation* | +| [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | *Unconditional Image Generation* | +| [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | *Unconditional Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Text-to-Image Generation* | +| [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Unconditional Image Generation* | +| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* | +| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | +| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* | + +**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. +However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below. + +## Pipelines API + +Diffusion models often consist of multiple independently-trained models or other previously existing components. + + +Each model has been trained independently on a different task and the scheduler can easily be swapped out and replaced with a different one. +During inference, we however want to be able to easily load all components and use them in inference - even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality: + +- [`from_pretrained` method](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L139) that accepts a Hugging Face Hub repository id, *e.g.* [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or a path to a local directory, *e.g.* +"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [runwayml/stable-diffusion-v1-5/model_index.json](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be +loaded into the pipelines. More specifically, for each model/component one needs to define the format `: ["", ""]`. `` is the attribute name given to the loaded instance of `` which can be found in the library or pipeline folder called `""`. +- [`save_pretrained`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L90) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`. +In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated +from the local path. +- [`to`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L118) which accepts a `string` or `torch.device` to move all models that are of type `torch.nn.Module` to the passed device. The behavior is fully analogous to [PyTorch's `to` method](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to). +- [`__call__`] method to use the pipeline in inference. `__call__` defines inference logic of the pipeline and should ideally encompass all aspects of it, from pre-processing to forwarding tensors to the different models and schedulers, as well as post-processing. The API of the `__call__` method can strongly vary from pipeline to pipeline. *E.g.* a text-to-image pipeline, such as [`StableDiffusionPipeline`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) should accept among other things the text prompt to generate the image. A pure image generation pipeline, such as [DDPMPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/ddpm) on the other hand can be run without providing any inputs. To better understand what inputs can be adapted for +each pipeline, one should look directly into the respective pipeline. + +**Note**: All pipelines have PyTorch's autograd disabled by decorating the `__call__` method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should +not be used for training. If you want to store the gradients during the forward pass, we recommend writing your own pipeline, see also our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community) + +## Contribution + +We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire +all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**. + +- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline. +- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and +use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most +logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method. +- **Easy-to-tweak**: Certain pipelines will not be able to handle all use cases and tasks that you might like them to. If you want to use a certain pipeline for a specific use case that is not yet supported, you might have to copy the pipeline file and tweak the code to your needs. We try to make the pipeline code as readable as possible so that each part –from pre-processing to diffusing to post-processing– can easily be adapted. If you would like the community to benefit from your customized pipeline, we would love to see a contribution to our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community). If you feel that an important pipeline should be part of the official pipelines but isn't, a contribution to the [official pipelines](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines) would be even better. +- **One-purpose-only**: Pipelines should be used for one task and one task only. Even if two tasks are very similar from a modeling point of view, *e.g.* image2image translation and in-painting, pipelines shall be used for one task only to keep them *easy-to-tweak* and *readable*. + +## Examples + +### Text-to-Image generation with Stable Diffusion + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler + +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] + +image.save("astronaut_rides_horse.png") +``` + +### Image-to-Image text-guided generation with Stable Diffusion + +The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images. + +```python +import requests +from PIL import Image +from io import BytesIO + +from diffusers import StableDiffusionImg2ImgPipeline + +# load the pipeline +device = "cuda" +pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16, +).to(device) + +# let's download an initial image +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((768, 512)) + +prompt = "A fantasy landscape, trending on artstation" + +images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + +images[0].save("fantasy_landscape.png") +``` +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) + +### Tweak prompts reusing seeds and latents + +You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb). + + +### In-painting using Stable Diffusion + +The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and text prompt. + +```python +import PIL +import requests +import torch +from io import BytesIO + +from diffusers import StableDiffusionInpaintPipeline + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) + +pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] +``` + +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/diffusers/pipelines/__init__.py b/diffusers/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..802ae4f5bc9490aa247812d65a4dc17606747fea --- /dev/null +++ b/diffusers/pipelines/__init__.py @@ -0,0 +1,188 @@ +from ..utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_invisible_watermark_available, + is_k_diffusion_available, + is_librosa_available, + is_note_seq_available, + is_onnx_available, + is_torch_available, + is_transformers_available, +) + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 +else: + from .consistency_models import ConsistencyModelPipeline + from .dance_diffusion import DanceDiffusionPipeline + from .ddim import DDIMPipeline + from .ddpm import DDPMPipeline + from .dit import DiTPipeline + from .latent_diffusion import LDMSuperResolutionPipeline + from .latent_diffusion_uncond import LDMPipeline + from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput + from .pndm import PNDMPipeline + from .repaint import RePaintPipeline + from .score_sde_ve import ScoreSdeVePipeline + from .stochastic_karras_ve import KarrasVePipeline + +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_librosa_objects import * # noqa F403 +else: + from .audio_diffusion import AudioDiffusionPipeline, Mel + +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline + from .audioldm import AudioLDMPipeline + from .controlnet import ( + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + ) + from .deepfloyd_if import ( + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, + ) + from .kandinsky import ( + KandinskyImg2ImgPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, + ) + from .kandinsky2_2 import ( + KandinskyV22ControlnetImg2ImgPipeline, + KandinskyV22ControlnetPipeline, + KandinskyV22Img2ImgPipeline, + KandinskyV22InpaintPipeline, + KandinskyV22Pipeline, + KandinskyV22PriorEmb2EmbPipeline, + KandinskyV22PriorPipeline, + ) + from .latent_diffusion import LDMTextToImagePipeline + from .paint_by_example import PaintByExamplePipeline + from .semantic_stable_diffusion import SemanticStableDiffusionPipeline + from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline + from .stable_diffusion import ( + CycleDiffusionPipeline, + StableDiffusionAttendAndExcitePipeline, + StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, + StableDiffusionImageVariationPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, + StableDiffusionLDM3DPipeline, + StableDiffusionModelEditingPipeline, + StableDiffusionPanoramaPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPipeline, + StableDiffusionPix2PixZeroPipeline, + StableDiffusionSAGPipeline, + StableDiffusionUpscalePipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .t2i_adapter import StableDiffusionAdapterPipeline + from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline + from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline + from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + from .vq_diffusion import VQDiffusionPipeline + + +try: + if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 +else: + from .controlnet import StableDiffusionXLControlNetPipeline + from .stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, + ) + +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_onnx_objects import * # noqa F403 +else: + from .onnx_utils import OnnxRuntimeModel + +try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 +else: + from .stable_diffusion import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, + OnnxStableDiffusionPipeline, + OnnxStableDiffusionUpscalePipeline, + StableDiffusionOnnxPipeline, + ) + +try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 +else: + from .stable_diffusion import StableDiffusionKDiffusionPipeline + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_objects import * # noqa F403 +else: + from .pipeline_flax_utils import FlaxDiffusionPipeline + + +try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 +else: + from .controlnet import FlaxStableDiffusionControlNetPipeline + from .stable_diffusion import ( + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + ) +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 +else: + from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline diff --git a/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c33b6696d01f832459cffd982a75a8e8018cc064 Binary files /dev/null and b/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceb2e953d66aeb81ecf4ec660b54414ba00852e0 Binary files /dev/null and b/diffusers/pipelines/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc b/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..193100f53deab1189c140af56c2a982a0ac1d4b9 Binary files /dev/null and b/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc differ diff --git a/diffusers/pipelines/__pycache__/pipeline_utils.cpython-38.pyc b/diffusers/pipelines/__pycache__/pipeline_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46c4aeee1e5f0cf7109de40874e96de6b455eb7f Binary files /dev/null and b/diffusers/pipelines/__pycache__/pipeline_utils.cpython-38.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/__init__.py b/diffusers/pipelines/alt_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dab2d8db1045ef27ff5d2234951c1488f547401b --- /dev/null +++ b/diffusers/pipelines/alt_diffusion/__init__.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt +class AltDiffusionPipelineOutput(BaseOutput): + """ + Output class for Alt Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +if is_transformers_available() and is_torch_available(): + from .modeling_roberta_series import RobertaSeriesModelWithTransformation + from .pipeline_alt_diffusion import AltDiffusionPipeline + from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43e61607e3f39a7092ce2a1a3ab49b2b9f96aad7 Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b74bf23dae3832663a8a3b35377b539e45879acc Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9b5ba958c7db1c78799ecd0cc671ea60255bcbe Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-38.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d2ef8a8ef2ece127bad886d836132a35d459127 Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-38.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b23733118c8bbd0b8b46100cf1ab1bb847266cd2 Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-38.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55f78c2677e7e6ce3a99850e3d3942ba8f16fb60 Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-38.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0e33f8242771561ac1c510f427890eceaf433cc Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-38.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c7c463e041b7d207e30e37555e5007570342de Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py new file mode 100644 index 0000000000000000000000000000000000000000..f73ef15d7de7948a9cbad246027ca71f4a6db198 --- /dev/null +++ b/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel +from transformers.utils import ModelOutput + + +@dataclass +class TransformationModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projection_state: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__( + self, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + project_dim=512, + pooler_fn="cls", + learn_encoder=False, + use_attention_mask=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + self.use_attention_mask = use_attention_mask + + +class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + base_model_prefix = "roberta" + config_class = RobertaSeriesConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.has_pre_transformation = getattr(config, "has_pre_transformation", False) + if self.has_pre_transformation: + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ): + r""" """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True if self.has_pre_transformation else output_hidden_states, + return_dict=return_dict, + ) + + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return TransformationModelOutput( + projection_state=projection_state2, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + projection_state = self.transformation(outputs.last_hidden_state) + return TransformationModelOutput( + projection_state=projection_state, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..507c082d93638b594e8be21607eae05f471aecfa --- /dev/null +++ b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -0,0 +1,732 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from packaging import version +from transformers import CLIPImageProcessor, XLMRobertaTokenizer + +from diffusers.utils import is_accelerate_available, is_accelerate_version + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AltDiffusionPipeline + + >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap" + >>> prompt = "黑暗精灵公主,非常详细,幻想,非常详细,数字绘画,概念艺术,敏锐的焦点,插图" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker +class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`RobertaSeriesModelWithTransformation`]): + Frozen text-encoder. Alt Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`XLMRobertaTokenizer`): + Tokenizer of class + [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..d6f297122ba4c93127fccf4b9c1ace9a90092bbd --- /dev/null +++ b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -0,0 +1,758 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, XLMRobertaTokenizer + +from diffusers.utils import is_accelerate_available, is_accelerate_version + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import AltDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "BAAI/AltDiffusion-m9" + >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> # "A fantasy landscape, trending on artstation" + >>> prompt = "幻想风景, artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("幻想风景.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker +class AltDiffusionImg2ImgPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image to image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`RobertaSeriesModelWithTransformation`]): + Frozen text-encoder. Alt Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`XLMRobertaTokenizer`): + Tokenizer of class + [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective" + f" batch size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Examples: + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/audio_diffusion/__init__.py b/diffusers/pipelines/audio_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58554c45ea52b9897293217652db36fdace7549f --- /dev/null +++ b/diffusers/pipelines/audio_diffusion/__init__.py @@ -0,0 +1,2 @@ +from .mel import Mel +from .pipeline_audio_diffusion import AudioDiffusionPipeline diff --git a/diffusers/pipelines/audio_diffusion/mel.py b/diffusers/pipelines/audio_diffusion/mel.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf28fd25a5a5d39416eaf6bfd76b7f6945f4b19 --- /dev/null +++ b/diffusers/pipelines/audio_diffusion/mel.py @@ -0,0 +1,160 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np # noqa: E402 + +from ...configuration_utils import ConfigMixin, register_to_config +from ...schedulers.scheduling_utils import SchedulerMixin + + +try: + import librosa # noqa: E402 + + _librosa_can_be_imported = True + _import_error = "" +except Exception as e: + _librosa_can_be_imported = False + _import_error = ( + f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it." + ) + + +from PIL import Image # noqa: E402 + + +class Mel(ConfigMixin, SchedulerMixin): + """ + Parameters: + x_res (`int`): x resolution of spectrogram (time) + y_res (`int`): y resolution of spectrogram (frequency bins) + sample_rate (`int`): sample rate of audio + n_fft (`int`): number of Fast Fourier Transforms + hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res) + top_db (`int`): loudest in decibels + n_iter (`int`): number of iterations for Griffin Linn mel inversion + """ + + config_name = "mel_config.json" + + @register_to_config + def __init__( + self, + x_res: int = 256, + y_res: int = 256, + sample_rate: int = 22050, + n_fft: int = 2048, + hop_length: int = 512, + top_db: int = 80, + n_iter: int = 32, + ): + self.hop_length = hop_length + self.sr = sample_rate + self.n_fft = n_fft + self.top_db = top_db + self.n_iter = n_iter + self.set_resolution(x_res, y_res) + self.audio = None + + if not _librosa_can_be_imported: + raise ValueError(_import_error) + + def set_resolution(self, x_res: int, y_res: int): + """Set resolution. + + Args: + x_res (`int`): x resolution of spectrogram (time) + y_res (`int`): y resolution of spectrogram (frequency bins) + """ + self.x_res = x_res + self.y_res = y_res + self.n_mels = self.y_res + self.slice_size = self.x_res * self.hop_length - 1 + + def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None): + """Load audio. + + Args: + audio_file (`str`): must be a file on disk due to Librosa limitation or + raw_audio (`np.ndarray`): audio as numpy array + """ + if audio_file is not None: + self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr) + else: + self.audio = raw_audio + + # Pad with silence if necessary. + if len(self.audio) < self.x_res * self.hop_length: + self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))]) + + def get_number_of_slices(self) -> int: + """Get number of slices in audio. + + Returns: + `int`: number of spectograms audio can be sliced into + """ + return len(self.audio) // self.slice_size + + def get_audio_slice(self, slice: int = 0) -> np.ndarray: + """Get slice of audio. + + Args: + slice (`int`): slice number of audio (out of get_number_of_slices()) + + Returns: + `np.ndarray`: audio as numpy array + """ + return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)] + + def get_sample_rate(self) -> int: + """Get sample rate: + + Returns: + `int`: sample rate of audio + """ + return self.sr + + def audio_slice_to_image(self, slice: int) -> Image.Image: + """Convert slice of audio to spectrogram. + + Args: + slice (`int`): slice number of audio to convert (out of get_number_of_slices()) + + Returns: + `PIL Image`: grayscale image of x_res x y_res + """ + S = librosa.feature.melspectrogram( + y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels + ) + log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db) + bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8) + image = Image.fromarray(bytedata) + return image + + def image_to_audio(self, image: Image.Image) -> np.ndarray: + """Converts spectrogram to audio. + + Args: + image (`PIL Image`): x_res x y_res grayscale image + + Returns: + audio (`np.ndarray`): raw audio + """ + bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width)) + log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db + S = librosa.db_to_power(log_S) + audio = librosa.feature.inverse.mel_to_audio( + S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter + ) + return audio diff --git a/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py b/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..629a2e7d32ca307c91b55359ccd93c8fb12884ff --- /dev/null +++ b/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py @@ -0,0 +1,249 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from math import acos, sin +from typing import List, Tuple, Union + +import numpy as np +import torch +from PIL import Image + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler, DDPMScheduler +from ...utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput +from .mel import Mel + + +class AudioDiffusionPipeline(DiffusionPipeline): + """ + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None + unet ([`UNet2DConditionModel`]): UNET model + mel ([`Mel`]): transform audio <-> spectrogram + scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler + """ + + _optional_components = ["vqvae"] + + def __init__( + self, + vqvae: AutoencoderKL, + unet: UNet2DConditionModel, + mel: Mel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + ): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) + + def get_default_steps(self) -> int: + """Returns default number of steps recommended for inference + + Returns: + `int`: number of steps + """ + return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000 + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + audio_file: str = None, + raw_audio: np.ndarray = None, + slice: int = 0, + start_step: int = 0, + steps: int = None, + generator: torch.Generator = None, + mask_start_secs: float = 0, + mask_end_secs: float = 0, + step_generator: torch.Generator = None, + eta: float = 0, + noise: torch.Tensor = None, + encoding: torch.Tensor = None, + return_dict=True, + ) -> Union[ + Union[AudioPipelineOutput, ImagePipelineOutput], + Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]], + ]: + """Generate random mel spectrogram from audio input and convert to audio. + + Args: + batch_size (`int`): number of samples to generate + audio_file (`str`): must be a file on disk due to Librosa limitation or + raw_audio (`np.ndarray`): audio as numpy array + slice (`int`): slice number of audio to convert + start_step (int): step to start from + steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM) + generator (`torch.Generator`): random number generator or None + mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start + mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end + step_generator (`torch.Generator`): random number generator used to de-noise or None + eta (`float`): parameter between 0 and 1 used with DDIM scheduler + noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None + encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim) + return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple + + Returns: + `List[PIL Image]`: mel spectrograms (`float`, `List[np.ndarray]`): sample rate and raw audios + """ + + steps = steps or self.get_default_steps() + self.scheduler.set_timesteps(steps) + step_generator = step_generator or generator + # For backwards compatibility + if type(self.unet.config.sample_size) == int: + self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) + if noise is None: + noise = randn_tensor( + ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size[0], + self.unet.config.sample_size[1], + ), + generator=generator, + device=self.device, + ) + images = noise + mask = None + + if audio_file is not None or raw_audio is not None: + self.mel.load_audio(audio_file, raw_audio) + input_image = self.mel.audio_slice_to_image(slice) + input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape( + (input_image.height, input_image.width) + ) + input_image = (input_image / 255) * 2 - 1 + input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device) + + if self.vqvae is not None: + input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample( + generator=generator + )[0] + input_images = self.vqvae.config.scaling_factor * input_images + + if start_step > 0: + images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) + + pixels_per_second = ( + self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length + ) + mask_start = int(mask_start_secs * pixels_per_second) + mask_end = int(mask_end_secs * pixels_per_second) + mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:])) + + for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])): + if isinstance(self.unet, UNet2DConditionModel): + model_output = self.unet(images, t, encoding)["sample"] + else: + model_output = self.unet(images, t)["sample"] + + if isinstance(self.scheduler, DDIMScheduler): + images = self.scheduler.step( + model_output=model_output, + timestep=t, + sample=images, + eta=eta, + generator=step_generator, + )["prev_sample"] + else: + images = self.scheduler.step( + model_output=model_output, + timestep=t, + sample=images, + generator=step_generator, + )["prev_sample"] + + if mask is not None: + if mask_start > 0: + images[:, :, :, :mask_start] = mask[:, step, :, :mask_start] + if mask_end > 0: + images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:] + + if self.vqvae is not None: + # 0.18215 was scaling factor used in training to ensure unit variance + images = 1 / self.vqvae.config.scaling_factor * images + images = self.vqvae.decode(images)["sample"] + + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).numpy() + images = (images * 255).round().astype("uint8") + images = list( + (Image.fromarray(_[:, :, 0]) for _ in images) + if images.shape[3] == 1 + else (Image.fromarray(_, mode="RGB").convert("L") for _ in images) + ) + + audios = [self.mel.image_to_audio(_) for _ in images] + if not return_dict: + return images, (self.mel.get_sample_rate(), audios) + + return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images)) + + @torch.no_grad() + def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray: + """Reverse step process: recover noisy image from generated image. + + Args: + images (`List[PIL Image]`): list of images to encode + steps (`int`): number of encoding steps to perform (defaults to 50) + + Returns: + `np.ndarray`: noise tensor of shape (batch_size, 1, height, width) + """ + + # Only works with DDIM as this method is deterministic + assert isinstance(self.scheduler, DDIMScheduler) + self.scheduler.set_timesteps(steps) + sample = np.array( + [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images] + ) + sample = (sample / 255) * 2 - 1 + sample = torch.Tensor(sample).to(self.device) + + for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): + prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[t] + alpha_prod_t_prev = ( + self.scheduler.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.scheduler.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + model_output = self.unet(sample, t)["sample"] + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output + sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5) + sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output + + return sample + + @staticmethod + def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor: + """Spherical Linear intERPolation + + Args: + x0 (`torch.Tensor`): first tensor to interpolate between + x1 (`torch.Tensor`): seconds tensor to interpolate between + alpha (`float`): interpolation between 0 and 1 + + Returns: + `torch.Tensor`: interpolated tensor + """ + + theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1)) + return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta) diff --git a/diffusers/pipelines/audioldm/__init__.py b/diffusers/pipelines/audioldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ddef6c3f3253afd1f59c14b685a5d14d7622150 --- /dev/null +++ b/diffusers/pipelines/audioldm/__init__.py @@ -0,0 +1,17 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AudioLDMPipeline, + ) +else: + from .pipeline_audioldm import AudioLDMPipeline diff --git a/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a35ebabb57a682b2c2d4a7d348eaf97da81f70d8 Binary files /dev/null and b/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bb97f4b8a55a2041615dba295162342566e8839 Binary files /dev/null and b/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc b/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a46b85fa2e2511efaa2b8c0407f29f09976ecdf Binary files /dev/null and b/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc differ diff --git a/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-38.pyc b/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08672343a352fa7fb50609eb36eab4dd29340a08 Binary files /dev/null and b/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-38.pyc differ diff --git a/diffusers/pipelines/audioldm/pipeline_audioldm.py b/diffusers/pipelines/audioldm/pipeline_audioldm.py new file mode 100644 index 0000000000000000000000000000000000000000..6da8e809103e3a2670ddec96e6612122a70f4b80 --- /dev/null +++ b/diffusers/pipelines/audioldm/pipeline_audioldm.py @@ -0,0 +1,566 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AudioLDMPipeline + + >>> pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A hammer hitting a wooden surface" + >>> audio = pipe(prompt).audios[0] + ``` +""" + + +class AudioLDMPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using AudioLDM. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode audios to and from latent representations. + text_encoder ([`ClapTextModelWithProjection`]): + Frozen text-encoder. AudioLDM uses the text portion of + [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap#transformers.ClapTextModelWithProjection), + specifically the [RoBERTa HSTAT-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. + tokenizer ([`PreTrainedTokenizer`]): + Tokenizer of class + [RobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer). + unet ([`UNet2DConditionModel`]): U-Net architecture to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`SpeechT5HifiGan`]): + Vocoder of class + [SpeechT5HifiGan](https://huggingface.co/docs/transformers/main/en/model_doc/speecht5#transformers.SpeechT5HifiGan). + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ClapTextModelWithProjection, + tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast], + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def _encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLAP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + prompt_embeds = prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + prompt_embeds = F.normalize(prompt_embeds, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + ( + bs_embed, + seq_len, + ) = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + return mel_spectrogram + + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu().float() + return waveform + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + self.vocoder.config.model_in_dim // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_length_in_s: Optional[float] = None, + num_inference_steps: int = 10, + guidance_scale: float = 2.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + output_type: Optional[str] = "np", + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the audio generation. If not defined, one has to pass `prompt_embeds`. + instead. + audio_length_in_s (`int`, *optional*, defaults to 5.12): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 10): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 2.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate audios that are closely linked to the text `prompt`, + usually at the expense of lower sound quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate image. Choose between: + - `"np"`: Return Numpy `np.ndarray` objects. + - `"pt"`: Return PyTorch `torch.Tensor` objects. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated audios. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=None, + class_labels=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + mel_spectrogram = self.decode_latents(latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/diffusers/pipelines/consistency_models/__init__.py b/diffusers/pipelines/consistency_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd78ddb3aae232a734bd911e92d8c9a07019945d --- /dev/null +++ b/diffusers/pipelines/consistency_models/__init__.py @@ -0,0 +1 @@ +from .pipeline_consistency_models import ConsistencyModelPipeline diff --git a/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71116b6dd095d20b3524478bf2df6b1c0fa04342 Binary files /dev/null and b/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad3e0b8225c672098a46a7f4c10d453a7f27c727 Binary files /dev/null and b/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-310.pyc b/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d22abe15e344025bd93db9e18279506d0ff052aa Binary files /dev/null and b/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-310.pyc differ diff --git a/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-38.pyc b/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2614bd0e8fce455ba6e3009181c16180e2984b81 Binary files /dev/null and b/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-38.pyc differ diff --git a/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/diffusers/pipelines/consistency_models/pipeline_consistency_models.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4af7afe5adf99aef45f8ff0578cefa39549182 --- /dev/null +++ b/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -0,0 +1,293 @@ +from typing import Callable, List, Optional, Union + +import torch + +from ...models import UNet2DModel +from ...schedulers import CMStochasticIterativeScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import ConsistencyModelPipeline + + >>> device = "cuda" + >>> # Load the cd_imagenet64_l2 checkpoint. + >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2" + >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe.to(device) + + >>> # Onestep Sampling + >>> image = pipe(num_inference_steps=1).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample.png") + + >>> # Onestep sampling, class-conditional image generation + >>> # ImageNet-64 class label 145 corresponds to king penguins + >>> image = pipe(num_inference_steps=1, class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png") + + >>> # Multistep sampling, class-conditional image generation + >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo: + >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 + >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png") + ``` +""" + + +class ConsistencyModelPipeline(DiffusionPipeline): + r""" + Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1]. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" + https://arxiv.org/pdf/2303.01469 + + Args: + unet ([`UNet2DModel`]): + Unconditional or class-conditional U-Net architecture to denoise image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible + with [`CMStochasticIterativeScheduler`]. + """ + + def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None: + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + ) + + self.safety_checker = None + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + + # Output_type must be 'pil' + sample = self.numpy_to_pil(sample) + return sample + + def prepare_class_labels(self, batch_size, device, class_labels=None): + if self.unet.config.num_class_embeds is not None: + if isinstance(class_labels, list): + class_labels = torch.tensor(class_labels, dtype=torch.int) + elif isinstance(class_labels, int): + assert batch_size == 1, "Batch size must be 1 if classes is an int" + class_labels = torch.tensor([class_labels], dtype=torch.int) + elif class_labels is None: + # Randomly generate batch_size class labels + # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils + class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) + class_labels = class_labels.to(device) + else: + class_labels = None + return class_labels + + def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps): + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + logger.warning( + f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;" + " `timesteps` will be used over `num_inference_steps`." + ) + + if latents is not None: + expected_shape = (batch_size, 3, img_size, img_size) + if latents.shape != expected_shape: + raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + batch_size: int = 1, + class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, + num_inference_steps: int = 1, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*): + Optional class labels for conditioning class-conditional consistency models. Will not be used if the + model is not class-conditional. + num_inference_steps (`int`, *optional*, defaults to 1): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Prepare call parameters + img_size = self.unet.config.sample_size + device = self._execution_device + + # 1. Check inputs + self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps) + + # 2. Prepare image latents + # Sample image latents x_0 ~ N(0, sigma_0^2 * I) + sample = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=img_size, + width=img_size, + dtype=self.unet.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 3. Handle class_labels for class-conditional models + class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Denoising loop + # Multistep sampling: implements Algorithm 1 in the paper + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + scaled_sample = self.scheduler.scale_model_input(sample, t) + model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0] + + sample = self.scheduler.step(model_output, t, sample, generator=generator)[0] + + # call the callback, if provided + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + # 6. Post-process image sample + image = self.postprocess_image(sample, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/controlnet/__init__.py b/diffusers/pipelines/controlnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd0b9a724e5e0a1a5cb864f0e5b8f698859f8d8a --- /dev/null +++ b/diffusers/pipelines/controlnet/__init__.py @@ -0,0 +1,27 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_invisible_watermark_available, + is_torch_available, + is_transformers_available, +) + + +if is_transformers_available() and is_torch_available() and is_invisible_watermark_available(): + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .multicontrolnet import MultiControlNetModel + from .pipeline_controlnet import StableDiffusionControlNetPipeline + from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline + from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + + +if is_transformers_available() and is_flax_available(): + from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline diff --git a/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a54a503882d87a2f553efc35da9d78a8a6677698 Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..975ff33fecdac637658cbe620de8a9fd2f835857 Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..782bedb4fc75f18c0dca8c55e27f508d3d8783ce Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-310.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bd510e9041a69bffa495e17860f00830e362683 Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-38.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b54423c3dbfbaadda20214188cde3881cb87e54c Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-310.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f519310705b76bb25f46427bda52ae2d46da70 Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-38.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6618d4ab24658e0cb462ea9a7d30fbcb39245a46 Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83ba08d05a298ada1d5a75df37750ee494589635 Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c82edc68f1680a4b1f86323ab98e5801c0f7d02e Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-310.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a001913f00c91cc7a5cd50189ee2cb6a028f6bc2 Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-38.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c06e4f57bdec1752894cb63b55b82d8a1240548 Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc differ diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2de3a54fd7b28f2b735ff866d7ac38e39f75f3ff Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-38.pyc differ diff --git a/diffusers/pipelines/controlnet/multicontrolnet.py b/diffusers/pipelines/controlnet/multicontrolnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2df8c7edd7aee9b374fc687d225da4fa827b24a5 --- /dev/null +++ b/diffusers/pipelines/controlnet/multicontrolnet.py @@ -0,0 +1,185 @@ +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...models.controlnet import ControlNetModel, ControlNetOutput +from ...models.modeling_utils import ModelMixin +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=image, + conditioning_scale=scale, + class_labels=class_labels, + timestep_cond=timestep_cond, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + guess_mode=guess_mode, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + """ + idx = 0 + model_path_to_save = save_directory + for controlnet in self.nets: + controlnet.save_pretrained( + model_path_to_save, + is_main_process=is_main_process, + save_function=save_function, + safe_serialization=safe_serialization, + variant=variant, + ) + + idx += 1 + model_path_to_save = model_path_to_save + f"_{idx}" + + @classmethod + def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_path (`os.PathLike`): + A path to a *directory* containing model weights saved using + [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g., + `./my_model_directory/controlnet`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from + `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. + """ + idx = 0 + controlnets = [] + + # load controlnet and append to list until no controlnet directory exists anymore + # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained` + # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ... + model_path_to_load = pretrained_model_path + while os.path.isdir(model_path_to_load): + controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs) + controlnets.append(controlnet) + + idx += 1 + model_path_to_load = pretrained_model_path + f"_{idx}" + + logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.") + + if len(controlnets) == 0: + raise ValueError( + f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." + ) + + return cls(controlnets) diff --git a/diffusers/pipelines/controlnet/pipeline_controlnet.py b/diffusers/pipelines/controlnet/pipeline_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2fcc0c67edf2b6e635de46821da519922cc48f17 --- /dev/null +++ b/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -0,0 +1,1012 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..c29a00a3542b6347a550a3c05d3776de19425273 --- /dev/null +++ b/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -0,0 +1,1105 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> np_image = np.array(image) + + >>> # get canny image + >>> np_image = cv2.Canny(np_image, 100, 200) + >>> np_image = np_image[:, :, None] + >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2) + >>> canny_image = Image.fromarray(np_image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", + ... num_inference_steps=20, + ... generator=generator, + ... image=image, + ... control_image=canny_image, + ... ).images[0] + ``` +""" + + +def prepare_image(image): + if isinstance(image, torch.Tensor): + # Batch single image + if image.ndim == 3: + image = image.unsqueeze(0) + + image = image.to(dtype=torch.float32) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + return image + + +class StableDiffusionControlNetImg2ImgPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.8, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accpet + image latents as `image`, if passing latents directly, it will not be encoded again. + control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting + than for [`~StableDiffusionControlNetPipeline.__call__`]. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + # 4. Prepare image + image = self.image_processor.preprocess(image).to(dtype=torch.float32) + + # 5. Prepare controlnet_conditioning_image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..b7481a0d4326c4560beafd3bde6631792de964d3 --- /dev/null +++ b/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -0,0 +1,1355 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_inpaint_condition(image, image_mask): + ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + ... image[image_mask > 0.5] = -1.0 # set as masked pixel + ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + ... image = torch.from_numpy(image) + ... return image + + + >>> control_image = make_inpaint_condition(init_image, mask_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class StableDiffusionControlNetInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + + + This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as + [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) + as well as default text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). + Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on + those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image] = None, + mask_image: Union[torch.Tensor, PIL.Image.Image] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.5, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting + than for [`~StableDiffusionControlNetPipeline.__call__`]. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..fd4338946e6ff2972c38d9dc8ea01a404d77d987 --- /dev/null +++ b/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -0,0 +1,960 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput +from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # To be updated when there's a useful ControlNet checkpoint + >>> # compatible with SDXL. + ``` +""" + + +class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + raise ValueError("MultiControlNet is not yet supported.") + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.watermark = StableDiffusionXLWatermarker() + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = (1024, 1024), + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = (1024, 1024), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` + containing the output images. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 7.2 Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..872297605683485544cdb12217bf679d5223a56c --- /dev/null +++ b/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -0,0 +1,537 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from ..stable_diffusion import FlaxStableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> from diffusers.utils import load_image + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel + + + >>> def image_grid(imgs, rows, cols): + ... w, h = imgs[0].size + ... grid = Image.new("RGB", size=(cols * w, rows * h)) + ... for i, img in enumerate(imgs): + ... grid.paste(img, box=(i % cols * w, i // cols * h)) + ... return grid + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> # get canny image + >>> canny_image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" + ... ) + + >>> prompts = "best quality, extremely detailed" + >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" + + >>> # load control net and stable diffusion v1-5 + >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 + ... ) + >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32 + ... ) + >>> params["controlnet"] = controlnet_params + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + + >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) + >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> negative_prompt_ids = shard(negative_prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipe( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... num_inference_steps=50, + ... neg_prompt_ids=negative_prompt_ids, + ... jit=True, + ... ).images + + >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + >>> output_images = image_grid(output_images, num_samples // 4, 4) + >>> output_images.save("generated_image.png") + ``` +""" + + +class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`FlaxControlNetModel`]: + Provides additional conditioning to the unet during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + controlnet: FlaxControlNetModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_text_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + return text_input.input_ids + + def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + return processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + controlnet_conditioning_scale: float = 1.0, + ): + height, width = image.shape[-2:] + if height % 64 != 0 or width % 64 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + image = jnp.concatenate([image] * 2) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + down_block_res_samples, mid_block_res_sample = self.controlnet.apply( + {"params": params["controlnet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + guidance_scale: Union[float, jnp.array] = 7.5, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt_ids (`jnp.array`): + The prompt or prompts to guide the image generation. + image (`jnp.array`): + Array representing the ControlNet input condition. ControlNet use this input condition to generate + guidance to Unet. + params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights + prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + + height, width = image.shape[-2:] + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if isinstance(controlnet_conditioning_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.array(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), + static_broadcasted_argnums=(0, 5), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + image = image.convert("RGB") + w, h = image.size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return image diff --git a/diffusers/pipelines/dance_diffusion/__init__.py b/diffusers/pipelines/dance_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55d7f8ff9807083a10c844f7003cf0696d8258a3 --- /dev/null +++ b/diffusers/pipelines/dance_diffusion/__init__.py @@ -0,0 +1 @@ +from .pipeline_dance_diffusion import DanceDiffusionPipeline diff --git a/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f42451c15f481fc23eed91a04e95cd1de62a0333 Binary files /dev/null and b/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f7111675e18cc7da8d3e71f3ef2e1aaf806f51c Binary files /dev/null and b/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc b/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7a32766b9f30a474d2ab9b08bebcbb0a0c134e0 Binary files /dev/null and b/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc differ diff --git a/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-38.pyc b/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec664e0be7f909df39bd25a322e1f4bea0b7df71 Binary files /dev/null and b/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-38.pyc differ diff --git a/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c3eb32273b6da97ac807f3e4fbebb5edd53bb879 --- /dev/null +++ b/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -0,0 +1,125 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch + +from ...utils import logging, randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DanceDiffusionPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded audio. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio. Can be one of + [`IPNDMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 100, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + audio_length_in_s: Optional[float] = None, + return_dict: bool = True, + ) -> Union[AudioPipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of audio samples to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher-quality audio sample at + the expense of slower inference. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): + The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* + `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated audio. + """ + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate + + sample_size = audio_length_in_s * self.unet.config.sample_rate + + down_scale_factor = 2 ** len(self.unet.up_blocks) + if sample_size < 3 * down_scale_factor: + raise ValueError( + f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" + f" {3 * down_scale_factor / self.unet.config.sample_rate}." + ) + + original_sample_size = int(sample_size) + if sample_size % down_scale_factor != 0: + sample_size = ( + (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1 + ) * down_scale_factor + logger.info( + f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled" + f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising" + " process." + ) + sample_size = int(sample_size) + + dtype = next(self.unet.parameters()).dtype + shape = (batch_size, self.unet.config.in_channels, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + audio = randn_tensor(shape, generator=generator, device=self._execution_device, dtype=dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, device=audio.device) + self.scheduler.timesteps = self.scheduler.timesteps.to(dtype) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(audio, t).sample + + # 2. compute previous audio sample: x_t -> t_t-1 + audio = self.scheduler.step(model_output, t, audio).prev_sample + + audio = audio.clamp(-1, 1).float().cpu().numpy() + + audio = audio[:, :, :original_sample_size] + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/diffusers/pipelines/ddim/__init__.py b/diffusers/pipelines/ddim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85e8118e75e7e4352f8efb12552ba9fff4bf491c --- /dev/null +++ b/diffusers/pipelines/ddim/__init__.py @@ -0,0 +1 @@ +from .pipeline_ddim import DDIMPipeline diff --git a/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..045acf07f2783feaeb5484b3aba4dd0d3bb9fae4 Binary files /dev/null and b/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..505a37aec1e246f5a08a8b40e0bf2c0caf600675 Binary files /dev/null and b/diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc b/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49ba3c57ea6d0ed4fe347fb014053d1fe9378766 Binary files /dev/null and b/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc differ diff --git a/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc b/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..410712086c01574b06e62b3eb124ee5bf0682230 Binary files /dev/null and b/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc differ diff --git a/diffusers/pipelines/ddim/pipeline_ddim.py b/diffusers/pipelines/ddim/pipeline_ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..c24aa6c797935cd90ff9dad25a5ca5c07686216d --- /dev/null +++ b/diffusers/pipelines/ddim/pipeline_ddim.py @@ -0,0 +1,122 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch + +from ...schedulers import DDIMScheduler +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DDIMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + + # make sure scheduler can always be converted to DDIM + scheduler = DDIMScheduler.from_config(scheduler.config) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + use_clipped_model_output: Optional[bool] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + eta (`float`, *optional*, defaults to 0.0): + The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + use_clipped_model_output (`bool`, *optional*, defaults to `None`): + if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed + downstream to the scheduler. So use `None` for schedulers which don't support this argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # eta corresponds to η in paper and should be between [0, 1] + # do x_t -> x_t-1 + image = self.scheduler.step( + model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator + ).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/ddpm/__init__.py b/diffusers/pipelines/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb228ee012e80493b617b314c867ecadba7ca1ce --- /dev/null +++ b/diffusers/pipelines/ddpm/__init__.py @@ -0,0 +1 @@ +from .pipeline_ddpm import DDPMPipeline diff --git a/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72457c7e29b43d24a3d06663e01789182fcfe88c Binary files /dev/null and b/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac1647bb5e9f229b9d4d3f55574b4c66c0f20ca2 Binary files /dev/null and b/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc b/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96cf144e91518159b1b81690e3bce51b2e7addba Binary files /dev/null and b/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc differ diff --git a/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc b/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71e5a15d92a2c283c002bc12dc95254761490069 Binary files /dev/null and b/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc differ diff --git a/diffusers/pipelines/ddpm/pipeline_ddpm.py b/diffusers/pipelines/ddpm/pipeline_ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..b4290daf852c2f3204a64b9955c9b53089d64bbc --- /dev/null +++ b/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -0,0 +1,105 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch + +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DDPMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 1000, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = randn_tensor(image_shape, generator=generator) + image = image.to(self.device) + else: + image = randn_tensor(image_shape, generator=generator, device=self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deepfloyd_if/__init__.py b/diffusers/pipelines/deepfloyd_if/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93414f20e7339a147ffa2d3dd36c871dfecda8e4 --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/__init__.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL + +from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available +from .timesteps import ( + fast27_timesteps, + smart27_timesteps, + smart50_timesteps, + smart100_timesteps, + smart185_timesteps, + super27_timesteps, + super40_timesteps, + super100_timesteps, +) + + +@dataclass +class IFPipelineOutput(BaseOutput): + """ + Args: + Output class for Stable Diffusion pipelines. + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content or a watermark. `None` if safety checking could not be performed. + watermark_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety + checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_detected: Optional[List[bool]] + watermark_detected: Optional[List[bool]] + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_if import IFPipeline + from .pipeline_if_img2img import IFImg2ImgPipeline + from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline + from .pipeline_if_inpainting import IFInpaintingPipeline + from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline + from .pipeline_if_superresolution import IFSuperResolutionPipeline + from .safety_checker import IFSafetyChecker + from .watermark import IFWatermarker diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..679bb7a4b36455c1953311722e404c9b701d7238 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a41bbec97202b97f94fce6eb87225b9c10d8b6ee Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ddeb021863501dd9e09fbd29954bfae49186fa5 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b22bd7001afa5362757b8ccde63ed646e36512b Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79993dfdc9021f21f5388f581bb332fb3c38d78f Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce6c54ff9b838c60b85504809d336a3464516a48 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba91903c44975c3df3a288d8d7b36087bd052680 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..674f4d2500046898a6e2433da2762e7321be3298 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e92b7ae88e9386fdd0be786acac33c3108e7ff93 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..203f15a76f7b317903f2d4c4e4426abb1163f5ec Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d565672b37f5dfca573826d524aa67db5030fe09 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5dd8f295bb53d067f27f64f6098dcf3a469571c Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2917ffa56dd37d2bc6040d6805cab4ab7a9631d6 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c579f7ab530c509c3c6b9e2bb2a320d0a35bd239 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25f33752dc5dc47c56e105355c03e5ab1537bf8c Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c09292ed709ee00f40af3eec29072a77d5494269 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31a5962c75f4bed23e9853698aa73a9fd9ce1973 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29c3af9e1597507385b5f99d1e07ba291d4f2144 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbf06ca889ba0615e4085ee0420e9aebb7662a1e Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-310.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35450ba13ad33d8c87985240ba7b3b388c1fe9d4 Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-38.pyc differ diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/diffusers/pipelines/deepfloyd_if/pipeline_if.py new file mode 100644 index 0000000000000000000000000000000000000000..ffcb1ab32d357dd0c546bd96def75207752e06cb --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -0,0 +1,816 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import LoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + + >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt" + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> safety_modules = { + ... "feature_extractor": pipe.feature_extractor, + ... "safety_checker": pipe.safety_checker, + ... "watermarker": pipe.watermarker, + ... } + >>> super_res_2_pipe = DiffusionPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 + ... ) + >>> super_res_2_pipe.enable_model_cpu_offload() + + >>> image = super_res_2_pipe( + ... prompt=prompt, + ... image=image, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFPipeline(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare intermediate images + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + self.unet.config.in_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + image = self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4c5229e1c48b4a561f24d88a776007cdf22f1a --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -0,0 +1,940 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import LoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image.resize((768, 512)) + + >>> pipe = IFImg2ImgPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A fantasy landscape in style minecraft" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", + ... text_encoder=None, + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + ): + _, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + image = self.scheduler.add_noise(image, noise, timestep) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 0.7, + num_inference_steps: int = 80, + timesteps: List[int] = None, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. Prepare intermediate images + image = self.preprocess_image(image) + image = image.to(device=device, dtype=dtype) + + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea0bafa51a011d4c1741c46090aaae55b1e06e5 --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -0,0 +1,1058 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import LoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image.resize((768, 512)) + + >>> pipe = IFImg2ImgPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A fantasy landscape in style minecraft" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", + ... text_encoder=None, + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warn( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + original_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # original_image + + if isinstance(original_image, list): + check_image_type = original_image[0] + else: + check_image_type = original_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(original_image, list): + image_batch_size = len(original_image) + elif isinstance(original_image, torch.Tensor): + image_batch_size = original_image.shape[0] + elif isinstance(original_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(original_image, np.ndarray): + image_batch_size = original_image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError( + f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image + def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.prepare_intermediate_images + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + ): + _, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + image = self.scheduler.add_noise(image, noise, timestep) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor], + original_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 0.8, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + original_image (`torch.FloatTensor` or `PIL.Image.Image`): + The original image that `image` was varied from. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + original_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. prepare original image + original_image = self.preprocess_original_image(original_image) + original_image = original_image.to(device=device, dtype=dtype) + + # 6. Prepare intermediate images + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + original_image, + noise_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + ) + + # 7. Prepare upscaled image and noise level + _, _, height, width = original_image.shape + + image = self.preprocess_image(image, num_images_per_prompt, device) + + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + image = self.numpy_to_pil(image) + + # 13. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..b11bf780de472e92325b749e516bdc7883c22297 --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -0,0 +1,1059 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import LoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" + >>> response = requests.get(url) + >>> mask_image = Image.open(BytesIO(response.content)) + >>> mask_image = mask_image + + >>> pipe = IFInpaintingPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "blue sunglasses" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... mask_image=mask_image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + mask_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # mask_image + + if isinstance(mask_image, list): + check_image_type = mask_image[0] + else: + check_image_type = mask_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(mask_image, list): + image_batch_size = len(mask_image) + elif isinstance(mask_image, torch.Tensor): + image_batch_size = mask_image.shape[0] + elif isinstance(mask_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(mask_image, np.ndarray): + image_batch_size = mask_image.shape[0] + else: + assert False + + if image_batch_size != 1 and batch_size != image_batch_size: + raise ValueError( + f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + def preprocess_mask_image(self, mask_image) -> torch.Tensor: + if not isinstance(mask_image, list): + mask_image = [mask_image] + + if isinstance(mask_image[0], torch.Tensor): + mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0) + + if mask_image.ndim == 2: + # Batch and add channel dim for single mask + mask_image = mask_image.unsqueeze(0).unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] == 1: + # Single mask, the 0'th dimension is considered to be + # the existing batch size of 1 + mask_image = mask_image.unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] != 1: + # Batch of mask, the 0'th dimension is considered to be + # the batching dimension + mask_image = mask_image.unsqueeze(1) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + + elif isinstance(mask_image[0], PIL.Image.Image): + new_mask_image = [] + + for mask_image_ in mask_image: + mask_image_ = mask_image_.convert("L") + mask_image_ = resize(mask_image_, self.unet.sample_size) + mask_image_ = np.array(mask_image_) + mask_image_ = mask_image_[None, None, :] + new_mask_image.append(mask_image_) + + mask_image = new_mask_image + + mask_image = np.concatenate(mask_image, axis=0) + mask_image = mask_image.astype(np.float32) / 255.0 + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + elif isinstance(mask_image[0], np.ndarray): + mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + return mask_image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None + ): + image_batch_size, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + noised_image = self.scheduler.add_noise(image, noise, timestep) + + image = (1 - mask_image) * image + mask_image * noised_image + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + mask_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + mask_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. Prepare intermediate images + image = self.preprocess_image(image) + image = image.to(device=device, dtype=dtype) + + mask_image = self.preprocess_mask_image(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + + if mask_image.shape[0] == 1: + mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + else: + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + prev_intermediate_images = intermediate_images + + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..2971570aada23ca1f88d212b7f82eb6e9c2118eb --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -0,0 +1,1169 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import LoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" + >>> response = requests.get(url) + >>> mask_image = Image.open(BytesIO(response.content)) + >>> mask_image = mask_image + + >>> pipe = IFInpaintingPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "blue sunglasses" + + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + >>> image = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... mask_image=mask_image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` + """ + + +class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warn( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + original_image, + mask_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # original_image + + if isinstance(original_image, list): + check_image_type = original_image[0] + else: + check_image_type = original_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(original_image, list): + image_batch_size = len(original_image) + elif isinstance(original_image, torch.Tensor): + image_batch_size = original_image.shape[0] + elif isinstance(original_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(original_image, np.ndarray): + image_batch_size = original_image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError( + f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" + ) + + # mask_image + + if isinstance(mask_image, list): + check_image_type = mask_image[0] + else: + check_image_type = mask_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(mask_image, list): + image_batch_size = len(mask_image) + elif isinstance(mask_image, torch.Tensor): + image_batch_size = mask_image.shape[0] + elif isinstance(mask_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(mask_image, np.ndarray): + image_batch_size = mask_image.shape[0] + else: + assert False + + if image_batch_size != 1 and batch_size != image_batch_size: + raise ValueError( + f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image + def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.preprocess_mask_image + def preprocess_mask_image(self, mask_image) -> torch.Tensor: + if not isinstance(mask_image, list): + mask_image = [mask_image] + + if isinstance(mask_image[0], torch.Tensor): + mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0) + + if mask_image.ndim == 2: + # Batch and add channel dim for single mask + mask_image = mask_image.unsqueeze(0).unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] == 1: + # Single mask, the 0'th dimension is considered to be + # the existing batch size of 1 + mask_image = mask_image.unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] != 1: + # Batch of mask, the 0'th dimension is considered to be + # the batching dimension + mask_image = mask_image.unsqueeze(1) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + + elif isinstance(mask_image[0], PIL.Image.Image): + new_mask_image = [] + + for mask_image_ in mask_image: + mask_image_ = mask_image_.convert("L") + mask_image_ = resize(mask_image_, self.unet.sample_size) + mask_image_ = np.array(mask_image_) + mask_image_ = mask_image_[None, None, :] + new_mask_image.append(mask_image_) + + mask_image = new_mask_image + + mask_image = np.concatenate(mask_image, axis=0) + mask_image = mask_image.astype(np.float32) / 255.0 + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + elif isinstance(mask_image[0], np.ndarray): + mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + return mask_image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.prepare_intermediate_images + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None + ): + image_batch_size, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + noised_image = self.scheduler.add_noise(image, noise, timestep) + + image = (1 - mask_image) * image + mask_image * noised_image + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor], + original_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + mask_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 0.8, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + original_image (`torch.FloatTensor` or `PIL.Image.Image`): + The original image that `image` was varied from. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to 0): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + original_image, + mask_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. prepare original image + original_image = self.preprocess_original_image(original_image) + original_image = original_image.to(device=device, dtype=dtype) + + # 6. prepare mask image + mask_image = self.preprocess_mask_image(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + + if mask_image.shape[0] == 1: + mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + else: + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + + # 6. Prepare intermediate images + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + original_image, + noise_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + mask_image, + generator, + ) + + # 7. Prepare upscaled image and noise level + _, _, height, width = original_image.shape + + image = self.preprocess_image(image, num_images_per_prompt, device) + + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + prev_intermediate_images = intermediate_images + + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + image = self.numpy_to_pil(image) + + # 13. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdd01fe748e411c032dbb5316c6761ecd5716e5 --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -0,0 +1,914 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import LoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + + >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warn( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})" + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images + def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def preprocess_image(self, image, num_images_per_prompt, device): + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: int = None, + width: int = None, + image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): + The image to be upscaled. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare intermediate images + num_channels = self.unet.config.in_channels // 2 + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + num_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare upscaled image and noise level + image = self.preprocess_image(image, num_images_per_prompt, device) + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 9. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 10. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 11. Convert to PIL + image = self.numpy_to_pil(image) + + # 12. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 9. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 10. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/diffusers/pipelines/deepfloyd_if/safety_checker.py b/diffusers/pipelines/deepfloyd_if/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffeed580bbea1514b11bf7a168a952328d8f424 --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/safety_checker.py @@ -0,0 +1,59 @@ +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class IFSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModelWithProjection(config.vision_config) + + self.p_head = nn.Linear(config.vision_config.projection_dim, 1) + self.w_head = nn.Linear(config.vision_config.projection_dim, 1) + + @torch.no_grad() + def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5): + image_embeds = self.vision_model(clip_input)[0] + + nsfw_detected = self.p_head(image_embeds) + nsfw_detected = nsfw_detected.flatten() + nsfw_detected = nsfw_detected > p_threshold + nsfw_detected = nsfw_detected.tolist() + + if any(nsfw_detected): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + for idx, nsfw_detected_ in enumerate(nsfw_detected): + if nsfw_detected_: + images[idx] = np.zeros(images[idx].shape) + + watermark_detected = self.w_head(image_embeds) + watermark_detected = watermark_detected.flatten() + watermark_detected = watermark_detected > w_threshold + watermark_detected = watermark_detected.tolist() + + if any(watermark_detected): + logger.warning( + "Potential watermarked content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + for idx, watermark_detected_ in enumerate(watermark_detected): + if watermark_detected_: + images[idx] = np.zeros(images[idx].shape) + + return images, nsfw_detected, watermark_detected diff --git a/diffusers/pipelines/deepfloyd_if/timesteps.py b/diffusers/pipelines/deepfloyd_if/timesteps.py new file mode 100644 index 0000000000000000000000000000000000000000..d44285c017bbb2ccffa4ae86dd77792a048625d9 --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/timesteps.py @@ -0,0 +1,579 @@ +fast27_timesteps = [ + 999, + 800, + 799, + 600, + 599, + 500, + 400, + 399, + 377, + 355, + 333, + 311, + 288, + 266, + 244, + 222, + 200, + 199, + 177, + 155, + 133, + 111, + 88, + 66, + 44, + 22, + 0, +] + +smart27_timesteps = [ + 999, + 976, + 952, + 928, + 905, + 882, + 858, + 857, + 810, + 762, + 715, + 714, + 572, + 429, + 428, + 286, + 285, + 238, + 190, + 143, + 142, + 118, + 95, + 71, + 47, + 24, + 0, +] + +smart50_timesteps = [ + 999, + 988, + 977, + 966, + 955, + 944, + 933, + 922, + 911, + 900, + 899, + 879, + 859, + 840, + 820, + 800, + 799, + 766, + 733, + 700, + 699, + 650, + 600, + 599, + 500, + 499, + 400, + 399, + 350, + 300, + 299, + 266, + 233, + 200, + 199, + 179, + 159, + 140, + 120, + 100, + 99, + 88, + 77, + 66, + 55, + 44, + 33, + 22, + 11, + 0, +] + +smart100_timesteps = [ + 999, + 995, + 992, + 989, + 985, + 981, + 978, + 975, + 971, + 967, + 964, + 961, + 957, + 956, + 951, + 947, + 942, + 937, + 933, + 928, + 923, + 919, + 914, + 913, + 908, + 903, + 897, + 892, + 887, + 881, + 876, + 871, + 870, + 864, + 858, + 852, + 846, + 840, + 834, + 828, + 827, + 820, + 813, + 806, + 799, + 792, + 785, + 784, + 777, + 770, + 763, + 756, + 749, + 742, + 741, + 733, + 724, + 716, + 707, + 699, + 698, + 688, + 677, + 666, + 656, + 655, + 645, + 634, + 623, + 613, + 612, + 598, + 584, + 570, + 569, + 555, + 541, + 527, + 526, + 505, + 484, + 483, + 462, + 440, + 439, + 396, + 395, + 352, + 351, + 308, + 307, + 264, + 263, + 220, + 219, + 176, + 132, + 88, + 44, + 0, +] + +smart185_timesteps = [ + 999, + 997, + 995, + 992, + 990, + 988, + 986, + 984, + 981, + 979, + 977, + 975, + 972, + 970, + 968, + 966, + 964, + 961, + 959, + 957, + 956, + 954, + 951, + 949, + 946, + 944, + 941, + 939, + 936, + 934, + 931, + 929, + 926, + 924, + 921, + 919, + 916, + 914, + 913, + 910, + 907, + 905, + 902, + 899, + 896, + 893, + 891, + 888, + 885, + 882, + 879, + 877, + 874, + 871, + 870, + 867, + 864, + 861, + 858, + 855, + 852, + 849, + 846, + 843, + 840, + 837, + 834, + 831, + 828, + 827, + 824, + 821, + 817, + 814, + 811, + 808, + 804, + 801, + 798, + 795, + 791, + 788, + 785, + 784, + 780, + 777, + 774, + 770, + 766, + 763, + 760, + 756, + 752, + 749, + 746, + 742, + 741, + 737, + 733, + 730, + 726, + 722, + 718, + 714, + 710, + 707, + 703, + 699, + 698, + 694, + 690, + 685, + 681, + 677, + 673, + 669, + 664, + 660, + 656, + 655, + 650, + 646, + 641, + 636, + 632, + 627, + 622, + 618, + 613, + 612, + 607, + 602, + 596, + 591, + 586, + 580, + 575, + 570, + 569, + 563, + 557, + 551, + 545, + 539, + 533, + 527, + 526, + 519, + 512, + 505, + 498, + 491, + 484, + 483, + 474, + 466, + 457, + 449, + 440, + 439, + 428, + 418, + 407, + 396, + 395, + 381, + 366, + 352, + 351, + 330, + 308, + 307, + 286, + 264, + 263, + 242, + 220, + 219, + 176, + 175, + 132, + 131, + 88, + 44, + 0, +] + +super27_timesteps = [ + 999, + 991, + 982, + 974, + 966, + 958, + 950, + 941, + 933, + 925, + 916, + 908, + 900, + 899, + 874, + 850, + 825, + 800, + 799, + 700, + 600, + 500, + 400, + 300, + 200, + 100, + 0, +] + +super40_timesteps = [ + 999, + 992, + 985, + 978, + 971, + 964, + 957, + 949, + 942, + 935, + 928, + 921, + 914, + 907, + 900, + 899, + 879, + 859, + 840, + 820, + 800, + 799, + 766, + 733, + 700, + 699, + 650, + 600, + 599, + 500, + 499, + 400, + 399, + 300, + 299, + 200, + 199, + 100, + 99, + 0, +] + +super100_timesteps = [ + 999, + 996, + 992, + 989, + 985, + 982, + 979, + 975, + 972, + 968, + 965, + 961, + 958, + 955, + 951, + 948, + 944, + 941, + 938, + 934, + 931, + 927, + 924, + 920, + 917, + 914, + 910, + 907, + 903, + 900, + 899, + 891, + 884, + 876, + 869, + 861, + 853, + 846, + 838, + 830, + 823, + 815, + 808, + 800, + 799, + 788, + 777, + 766, + 755, + 744, + 733, + 722, + 711, + 700, + 699, + 688, + 677, + 666, + 655, + 644, + 633, + 622, + 611, + 600, + 599, + 585, + 571, + 557, + 542, + 528, + 514, + 500, + 499, + 485, + 471, + 457, + 442, + 428, + 414, + 400, + 399, + 379, + 359, + 340, + 320, + 300, + 299, + 279, + 259, + 240, + 220, + 200, + 199, + 166, + 133, + 100, + 99, + 66, + 33, + 0, +] diff --git a/diffusers/pipelines/deepfloyd_if/watermark.py b/diffusers/pipelines/deepfloyd_if/watermark.py new file mode 100644 index 0000000000000000000000000000000000000000..db33dec0ef9ad5909e79358e9d89bdc0ed9c9909 --- /dev/null +++ b/diffusers/pipelines/deepfloyd_if/watermark.py @@ -0,0 +1,46 @@ +from typing import List + +import PIL +import torch +from PIL import Image + +from ...configuration_utils import ConfigMixin +from ...models.modeling_utils import ModelMixin +from ...utils import PIL_INTERPOLATION + + +class IFWatermarker(ModelMixin, ConfigMixin): + def __init__(self): + super().__init__() + + self.register_buffer("watermark_image", torch.zeros((62, 62, 4))) + self.watermark_image_as_pil = None + + def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None): + # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 + + h = images[0].height + w = images[0].width + + sample_size = sample_size or h + + coef = min(h / sample_size, w / sample_size) + img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) + + S1, S2 = 1024**2, img_w * img_h + K = (S2 / S1) ** 0.5 + wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) + + if self.watermark_image_as_pil is None: + watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy() + watermark_image = Image.fromarray(watermark_image, mode="RGBA") + self.watermark_image_as_pil = watermark_image + + wm_img = self.watermark_image_as_pil.resize( + (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None + ) + + for pil_img in images: + pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) + + return images diff --git a/diffusers/pipelines/dit/__init__.py b/diffusers/pipelines/dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef0729cb4905d5e177ba15533375fce50084406 --- /dev/null +++ b/diffusers/pipelines/dit/__init__.py @@ -0,0 +1 @@ +from .pipeline_dit import DiTPipeline diff --git a/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32dd16165b7eaa736a9c50f18b4662a9537b4b4b Binary files /dev/null and b/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/dit/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/dit/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19eaacc0386c7f7336f2fa59cf5a83c099094217 Binary files /dev/null and b/diffusers/pipelines/dit/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc b/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..886af5f803953a880efeb8d02562b834e8a00fa5 Binary files /dev/null and b/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc differ diff --git a/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-38.pyc b/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0906480834d5f9ccc32ddaa0fca641a56dab659 Binary files /dev/null and b/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-38.pyc differ diff --git a/diffusers/pipelines/dit/pipeline_dit.py b/diffusers/pipelines/dit/pipeline_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..07fd2835ccf0f2f6d485b45b398dd25049c75228 --- /dev/null +++ b/diffusers/pipelines/dit/pipeline_dit.py @@ -0,0 +1,199 @@ +# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) +# William Peebles and Saining Xie +# +# Copyright (c) 2021 OpenAI +# MIT License +# +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from ...models import AutoencoderKL, Transformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DiTPipeline(DiffusionPipeline): + r""" + This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + transformer ([`Transformer2DModel`]): + Class conditioned Transformer in Diffusion model to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `dit` to denoise the encoded image latents. + """ + + def __init__( + self, + transformer: Transformer2DModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + id2label: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) + + # create a imagenet -> id dictionary for easier use + self.labels = {} + if id2label is not None: + for key, value in id2label.items(): + for label in value.split(","): + self.labels[label.lstrip().rstrip()] = int(key) + self.labels = dict(sorted(self.labels.items())) + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + + Map label strings, *e.g.* from ImageNet, to corresponding class ids. + + Parameters: + label (`str` or `dict` of `str`): label strings to be mapped to class ids. + + Returns: + `list` of `int`: Class ids to be processed by pipeline. + """ + + if not isinstance(label, list): + label = list(label) + + for l in label: + if l not in self.labels: + raise ValueError( + f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}." + ) + + return [self.labels[l] for l in label] + + @torch.no_grad() + def __call__( + self, + class_labels: List[int], + guidance_scale: float = 4.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Function invoked when calling the pipeline for generation. + + Args: + class_labels (List[int]): + List of imagenet class labels for the images to be generated. + guidance_scale (`float`, *optional*, defaults to 4.0): + Scale of the guidance signal. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + num_inference_steps (`int`, *optional*, defaults to 250): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + """ + + batch_size = len(class_labels) + latent_size = self.transformer.config.sample_size + latent_channels = self.transformer.config.in_channels + + latents = randn_tensor( + shape=(batch_size, latent_channels, latent_size, latent_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents + + class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1) + class_null = torch.tensor([1000] * batch_size, device=self._execution_device) + class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale > 1: + half = latent_model_input[: len(latent_model_input) // 2] + latent_model_input = torch.cat([half, half], dim=0) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + timesteps = t + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(latent_model_input.shape[0]) + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, timestep=timesteps, class_labels=class_labels_input + ).sample + + # perform guidance + if guidance_scale > 1: + eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + + half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + + noise_pred = torch.cat([eps, rest], dim=1) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + model_output, _ = torch.split(noise_pred, latent_channels, dim=1) + else: + model_output = noise_pred + + # compute previous image: x_t -> x_t-1 + latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample + + if guidance_scale > 1: + latents, _ = latent_model_input.chunk(2, dim=0) + else: + latents = latent_model_input + + latents = 1 / self.vae.config.scaling_factor * latents + samples = self.vae.decode(latents).sample + + samples = (samples / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + samples = self.numpy_to_pil(samples) + + if not return_dict: + return (samples,) + + return ImagePipelineOutput(images=samples) diff --git a/diffusers/pipelines/kandinsky/__init__.py b/diffusers/pipelines/kandinsky/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..242ff799e529abbb268b3562a9671db42d9de37e --- /dev/null +++ b/diffusers/pipelines/kandinsky/__init__.py @@ -0,0 +1,19 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import KandinskyPipeline, KandinskyPriorPipeline +else: + from .pipeline_kandinsky import KandinskyPipeline + from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline + from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline + from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput + from .text_encoder import MultilingualCLIP diff --git a/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a8cc27936214128b0d2d8d9f9a5921c83b2fe6a Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..002c546c28e7144a22e81d1eaa675ae046357498 Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09a72cfa0543e44fbd5a4d86cf16504df16ee8bc Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e3b2501783290c5c6f97761f8ad59b1e6bbfb5e Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa36860a2e14dec3318d43da49259d1c78e97fd3 Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e938f6793940f0e60bce0455cfc83d9c2a86337 Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55005faafefbef98e84fe66a19dfa821f5994da9 Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a24732a35092fb398e6b4a081000fc9c5f338e6 Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e43c43de80edbbca5374f14a68dcdc0de74e0bf Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb762bca6343c0b550e5643a6bda4d2dbcc37fcc Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff5286fe1e4b74bb3a1623a4ee199c158bd692fe Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bcae69dab418b390b94f5e5ddd72890506a6155 Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/diffusers/pipelines/kandinsky/pipeline_kandinsky.py new file mode 100644 index 0000000000000000000000000000000000000000..31edb915e8f76df6910a4a336286cc0b93df8722 --- /dev/null +++ b/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -0,0 +1,418 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDIMScheduler, DDPMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/Kandinsky-2-1-prior") + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> negative_image_emb = out.negative_image_embeds + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") + >>> pipe.to("cuda") + + >>> image = pipe( + ... prompt, + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +class KandinskyPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + movq: VQModel, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=77, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e4b04febacc5c751f59c37f6102268f497166d --- /dev/null +++ b/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -0,0 +1,504 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from PIL import Image +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDIMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "A red cartoon frog, 4k" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyImg2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/frog.png" + ... ) + + >>> image = pipe( + ... prompt, + ... image=init_image, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... strength=0.2, + ... ).images + + >>> image[0].save("red_frog.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +def prepare_image(pil_image, w=512, h=512): + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class KandinskyImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ image encoder and decoder + """ + + def __init__( + self, + text_encoder: MultilingualCLIP, + movq: VQModel, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + + shape = latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + latents = self.add_noise(latents, noise, latent_timestep) + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # add_noise method to overwrite the one in schedule because it use a different beta schedule for adding noise vs sampling + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + + return noisy_samples + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], + image_embeds: torch.FloatTensor, + negative_image_embeds: torch.FloatTensor, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + strength: float = 0.3, + guidance_scale: float = 7.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor`, `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + # 1. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. get text and image embeddings + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + # 3. pre-processing initial image + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + # the formular to calculate timestep for add_noise is taken from the original kandinsky repo + latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2 + + latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device) + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width, self.movq_scale_factor) + + # 5. Create initial latent + latents = self.prepare_latents( + latents, + latent_timestep, + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + self.scheduler, + ) + + # 6. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + # 7. post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a7dc225947d5f70ad05c13886c79c4b34ee70b --- /dev/null +++ b/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -0,0 +1,630 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDIMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + >>> import numpy as np + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "a hat" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyInpaintPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> mask = np.ones((768, 768), dtype=np.float32) + >>> mask[:250, 250:-250] = 0 + + >>> out = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ) + + >>> image = out.images[0] + >>> image.save("cat_with_hat.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +def prepare_mask(masks): + prepared_masks = [] + for mask in masks: + old_mask = deepcopy(mask) + for i in range(mask.shape[1]): + for j in range(mask.shape[2]): + if old_mask[0][i][j] == 1: + continue + if i != 0: + mask[:, i - 1, j] = 0 + if j != 0: + mask[:, i, j - 1] = 0 + if i != 0 and j != 0: + mask[:, i - 1, j - 1] = 0 + if i != mask.shape[1] - 1: + mask[:, i + 1, j] = 0 + if j != mask.shape[2] - 1: + mask[:, i, j + 1] = 0 + if i != mask.shape[1] - 1 and j != mask.shape[2] - 1: + mask[:, i + 1, j + 1] = 0 + prepared_masks.append(mask) + return torch.stack(prepared_masks, dim=0) + + +def prepare_mask_and_masked_image(image, mask, height, width): + r""" + Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will + be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for + the ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + return mask, image + + +class KandinskyInpaintPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image inpainting using Kandinsky2.1 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ image encoder and decoder + """ + + def __init__( + self, + text_encoder: MultilingualCLIP, + movq: VQModel, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + movq=movq, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + image_embeds: torch.FloatTensor, + negative_image_embeds: torch.FloatTensor, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor`, `PIL.Image.Image` or `np.ndarray`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`PIL.Image.Image`,`torch.FloatTensor` or `np.ndarray`): + `Image`, or a tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. You can pass a pytorch tensor as mask only if the + image you passed is a pytorch tensor, and it should contain one color channel (L) instead of 3, so the + expected shape would be either `(B, 1, H, W,)`, `(B, H, W)`, `(1, H, W)` or `(H, W)` If image is an PIL + image or numpy array, mask should also be a either PIL image or numpy array. If it is a PIL image, it + will be converted to a single channel (luminance) before use. If it is a nummpy array, the expected + shape is `(H, W)`. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + # Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + # preprocess image and mask + mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) + + image = image.to(dtype=prompt_embeds.dtype, device=device) + image = self.movq.encode(image)["latents"] + + mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device) + + image_shape = tuple(image.shape[-2:]) + mask_image = F.interpolate( + mask_image, + image_shape, + mode="nearest", + ) + mask_image = prepare_mask(mask_image) + masked_image = image * mask_image + + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + mask_image = mask_image.repeat(2, 1, 1, 1) + masked_image = masked_image.repeat(2, 1, 1, 1) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.movq.config.latent_channels + + # get h, w for latents + sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, sample_height, sample_width), + text_encoder_hidden_states.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # Check that sizes of mask, masked image and latents match with expected + num_channels_mask = mask_image.shape[1] + num_channels_masked_image = masked_image.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..4d99180b3d1106b553b69529e7ede3bfdef4b65f --- /dev/null +++ b/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py @@ -0,0 +1,541 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import UnCLIPScheduler +from ...utils import ( + BaseOutput, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior") + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> negative_image_emb = out.negative_image_embeds + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") + >>> pipe.to("cuda") + + >>> image = pipe( + ... prompt, + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline + >>> from diffusers.utils import load_image + >>> import PIL + + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) + >>> pipe.to("cuda") + + >>> image = pipe( + ... "", + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=150, + ... ).images[0] + + >>> image.save("starry_cat.png") + ``` +""" + + +@dataclass +class KandinskyPriorPipelineOutput(BaseOutput): + """ + Output class for KandinskyPriorPipeline. + + Args: + image_embeds (`torch.FloatTensor`) + clip image embeddings for text prompt + negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`) + clip image embeddings for unconditional tokens + """ + + image_embeds: Union[torch.FloatTensor, np.ndarray] + negative_image_embeds: Union[torch.FloatTensor, np.ndarray] + + +class KandinskyPriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _exclude_from_cpu_offload = ["prior"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], + weights: List[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + negative_prior_prompt: Optional[str] = None, + negative_prompt: Union[str] = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): + list of prompts and images to guide the image generation. + weights: (`List[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + if isinstance(cond, PIL.Image.Image): + cond = ( + self.image_processor(cond, return_tensors="pt") + .pixel_values[0] + .unsqueeze(0) + .to(dtype=self.image_encoder.dtype, device=device) + ) + + image_emb = self.image_encoder(cond)["image_embeds"] + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True) + + out_zero = self( + negative_prompt, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ) + zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + output_type: Optional[str] = "pt", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + prior_timesteps_tensor = self.scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/diffusers/pipelines/kandinsky/text_encoder.py b/diffusers/pipelines/kandinsky/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..caa0029f00ca22818819d5b76b57ec489c6da1d6 --- /dev/null +++ b/diffusers/pipelines/kandinsky/text_encoder.py @@ -0,0 +1,27 @@ +import torch +from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel + + +class MCLIPConfig(XLMRobertaConfig): + model_type = "M-CLIP" + + def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs): + self.transformerDimensions = transformerDimSize + self.numDims = imageDimSize + super().__init__(**kwargs) + + +class MultilingualCLIP(PreTrainedModel): + config_class = MCLIPConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.transformer = XLMRobertaModel(config) + self.LinearTransformation = torch.nn.Linear( + in_features=config.transformerDimensions, out_features=config.numDims + ) + + def forward(self, input_ids, attention_mask): + embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0] + embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None] + return self.LinearTransformation(embs2), embs diff --git a/diffusers/pipelines/kandinsky2_2/__init__.py b/diffusers/pipelines/kandinsky2_2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..648164b9f1ba657feb686a70ad2a4e367f898e20 --- /dev/null +++ b/diffusers/pipelines/kandinsky2_2/__init__.py @@ -0,0 +1,7 @@ +from .pipeline_kandinsky2_2 import KandinskyV22Pipeline +from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline +from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline +from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline +from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline +from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline +from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeb8fdfef5461688b1e4682c7361726a12e39c75 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f026433b3fcc6784da85c9aefe92f4f9b95aab50 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f21ed28416d19815453a4d783d82e5b53cd8399 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..483d98bd87ff8442a8b8d56a95530b4706d61b9d Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a47f9c194095fb7d56d7406400ce639a50e8099 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8164953faae83a22c730c0a3960502adc358bf22 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98fc4f8679d66baa34b966668bf2432cfee2e1ca Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..078d473864639d099f0c996a31be139ce1860757 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99a294348529832329145fdab092cd925a3ace94 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b658ffa0c19bc29033b431faed31fa3149d93346 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..709546e6617e18d1969c619c0a9653aae74df741 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0869b29f128d52755b38f7d98203a642da4ff9d Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..215d89367edda2d601230003afee72f70c378c2f Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45add2c13fab29e27f453d8d3b6f4d9b1df58b55 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..057c72e4a00e4af38ff82870c1d9a30f049dd401 Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-310.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..072ff2d46a0eb296655632b1184485c351ca57ae Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-38.pyc differ diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py new file mode 100644 index 0000000000000000000000000000000000000000..6db5260b04d82d7301d887dc3095b7b64aff0ab5 --- /dev/null +++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py @@ -0,0 +1,277 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDPMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") + >>> pipe_prior.to("cuda") + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> zero_image_emb = out.negative_image_embeds + >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images + >>> image[0].save("cat.png") + ``` +""" + + +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class KandinskyV22Pipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Args: + Function invoked when calling the pipeline for generation. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] * num_images_per_prompt + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.unet.config.in_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6fc269117470dd0130b3546ece0f509e936126e --- /dev/null +++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py @@ -0,0 +1,331 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDPMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import numpy as np + + >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline + >>> from transformers import pipeline + >>> from diffusers.utils import load_image + + + >>> def make_hint(image, depth_estimator): + ... image = depth_estimator(image)["depth"] + ... image = np.array(image) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... detected_map = torch.from_numpy(image).float() / 255.0 + ... hint = detected_map.permute(2, 0, 1) + ... return hint + + + >>> depth_estimator = pipeline("depth-estimation") + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior = pipe_prior.to("cuda") + + >>> pipe = KandinskyV22ControlnetPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((768, 768)) + + >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") + + >>> prompt = "A robot, 4k photo" + >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" + + >>> generator = torch.Generator(device="cuda").manual_seed(43) + + >>> image_emb, zero_image_emb = pipe_prior( + ... prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator + ... ).to_tuple() + + >>> images = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... hint=hint, + ... num_inference_steps=50, + ... generator=generator, + ... height=768, + ... width=768, + ... ).images + + >>> images[0].save("robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class KandinskyV22ControlnetPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + hint: torch.FloatTensor, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + hint (`torch.FloatTensor`): + The controlnet condition. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + if isinstance(hint, list): + hint = torch.cat(hint, dim=0) + + batch_size = image_embeds.shape[0] * num_images_per_prompt + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hint = hint.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device) + hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.movq.config.latent_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..8a25624b7267bd3f8d82fa7fc6514920193adfc0 --- /dev/null +++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py @@ -0,0 +1,393 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from PIL import Image + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDPMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import numpy as np + + >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline + >>> from transformers import pipeline + >>> from diffusers.utils import load_image + + + >>> def make_hint(image, depth_estimator): + ... image = depth_estimator(image)["depth"] + ... image = np.array(image) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... detected_map = torch.from_numpy(image).float() / 255.0 + ... hint = detected_map.permute(2, 0, 1) + ... return hint + + + >>> depth_estimator = pipeline("depth-estimation") + + >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior = pipe_prior.to("cuda") + + >>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((768, 768)) + + + >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") + + >>> prompt = "A robot, 4k photo" + >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" + + >>> generator = torch.Generator(device="cuda").manual_seed(43) + + >>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator) + >>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator) + + >>> images = pipe( + ... image=img, + ... strength=0.5, + ... image_embeds=img_emb.image_embeds, + ... negative_image_embeds=negative_emb.image_embeds, + ... hint=hint, + ... num_inference_steps=50, + ... generator=generator, + ... height=768, + ... width=768, + ... ).images + + >>> images[0].save("robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image +def prepare_image(pil_image, w=512, h=512): + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], + negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + hint: torch.FloatTensor, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + strength: float = 0.3, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + hint (`torch.FloatTensor`): + The controlnet condition. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + if isinstance(hint, list): + hint = torch.cat(hint, dim=0) + + batch_size = image_embeds.shape[0] + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hint = hint.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device) + hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device) + + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) + image = image.to(dtype=image_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator + ) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..26976ad0c925f09f2f550741c4c7a26f83c68700 --- /dev/null +++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py @@ -0,0 +1,357 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from PIL import Image + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDPMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "A red cartoon frog, 4k" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/frog.png" + ... ) + + >>> image = pipe( + ... image=init_image, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... strength=0.2, + ... ).images + + >>> image[0].save("red_frog.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image +def prepare_image(pil_image, w=512, h=512): + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class KandinskyV22Img2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], + negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + strength: float = 0.3, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device) + + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) + image = image.to(dtype=image_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator + ) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..27bd01984dbf2bfb6d1bf3d848dc998c87d3077a --- /dev/null +++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py @@ -0,0 +1,490 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from PIL import Image + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDPMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + >>> import numpy as np + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "a hat" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyV22InpaintPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> mask = np.ones((768, 768), dtype=np.float32) + >>> mask[:250, 250:-250] = 0 + + >>> out = pipe( + ... image=init_image, + ... mask_image=mask, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ) + + >>> image = out.images[0] + >>> image.save("cat_with_hat.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask +def prepare_mask(masks): + prepared_masks = [] + for mask in masks: + old_mask = deepcopy(mask) + for i in range(mask.shape[1]): + for j in range(mask.shape[2]): + if old_mask[0][i][j] == 1: + continue + if i != 0: + mask[:, i - 1, j] = 0 + if j != 0: + mask[:, i, j - 1] = 0 + if i != 0 and j != 0: + mask[:, i - 1, j - 1] = 0 + if i != mask.shape[1] - 1: + mask[:, i + 1, j] = 0 + if j != mask.shape[2] - 1: + mask[:, i, j + 1] = 0 + if i != mask.shape[1] - 1 and j != mask.shape[2] - 1: + mask[:, i + 1, j + 1] = 0 + prepared_masks.append(mask) + return torch.stack(prepared_masks, dim=0) + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width): + r""" + Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will + be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for + the ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + return mask, image + + +class KandinskyV22InpaintPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image inpainting using Kandinsky2.1 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Args: + Function invoked when calling the pipeline for generation. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`np.array`): + Tensor representing an image batch, to mask `image`. Black pixels in the mask will be repainted, while + white pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single + channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, + so the expected shape would be `(B, H, W, 1)`. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] * num_images_per_prompt + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + # preprocess image and mask + mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) + + image = image.to(dtype=image_embeds.dtype, device=device) + image = self.movq.encode(image)["latents"] + + mask_image = mask_image.to(dtype=image_embeds.dtype, device=device) + + image_shape = tuple(image.shape[-2:]) + mask_image = F.interpolate( + mask_image, + image_shape, + mode="nearest", + ) + mask_image = prepare_mask(mask_image) + masked_image = image * mask_image + + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + mask_image = mask_image.repeat(2, 1, 1, 1) + masked_image = masked_image.repeat(2, 1, 1, 1) + + num_channels_latents = self.movq.config.latent_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + noise = torch.clone(latents) + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + init_latents_proper = image[:1] + init_mask = mask_image[:1] + + if i < len(timesteps_tensor) - 1: + noise_timestep = timesteps_tensor[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = init_mask * init_latents_proper + (1 - init_mask) * latents + # post-processing + latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..1bfc6523cdf944a0cbc31e159eb0199032b53216 --- /dev/null +++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py @@ -0,0 +1,501 @@ +from typing import List, Optional, Union + +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import UnCLIPScheduler +from ...utils import ( + logging, + randn_tensor, + replace_example_docstring, +) +from ..kandinsky import KandinskyPriorPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") + >>> pipe_prior.to("cuda") + >>> prompt = "red cat, 4k photo" + >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple() + + >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline + >>> from diffusers.utils import load_image + >>> import PIL + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> out = pipe_prior.interpolate(images_texts, weights) + >>> pipe = KandinskyV22Pipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=out.image_embeds, + ... negative_image_embeds=out.negative_image_embeds, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images[0] + >>> image.save("starry_cat.png") + ``` +""" + + +class KandinskyV22PriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + _exclude_from_cpu_offload = ["prior"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], + weights: List[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + negative_prior_prompt: Optional[str] = None, + negative_prompt: Union[str] = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): + list of prompts and images to guide the image generation. + weights: (`List[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds.unsqueeze(0) + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + if isinstance(cond, PIL.Image.Image): + cond = ( + self.image_processor(cond, return_tensors="pt") + .pixel_values[0] + .unsqueeze(0) + .to(dtype=self.image_encoder.dtype, device=device) + ) + + image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0) + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0) + + out_zero = self( + negative_prompt, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ) + zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + prior_timesteps_tensor = self.scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py new file mode 100644 index 0000000000000000000000000000000000000000..bcbeb78bac1020c29e3c1c22dac4dc455714b9ab --- /dev/null +++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py @@ -0,0 +1,565 @@ +from typing import List, Optional, Union + +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import UnCLIPScheduler +from ...utils import ( + logging, + randn_tensor, + replace_example_docstring, +) +from ..kandinsky import KandinskyPriorPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + >>> image_emb, nagative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple() + + >>> pipe = KandinskyPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder, torch_dtype=torch.float16" + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline + >>> from diffusers.utils import load_image + >>> import PIL + + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) + + >>> pipe = KandinskyV22Pipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=150, + ... ).images[0] + + >>> image.save("starry_cat.png") + ``` +""" + + +class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _exclude_from_cpu_offload = ["prior"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], + weights: List[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + negative_prior_prompt: Optional[str] = None, + negative_prompt: Union[str] = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): + list of prompts and images to guide the image generation. + weights: (`List[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds.unsqueeze(0) + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + image_emb = self._encode_image( + cond, device=device, num_images_per_prompt=num_images_per_prompt + ).unsqueeze(0) + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0) + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb)) + + def _encode_image( + self, + image: Union[torch.Tensor, List[PIL.Image.Image]], + device, + num_images_per_prompt, + ): + if not isinstance(image, torch.Tensor): + image = self.image_processor(image, return_tensors="pt").pixel_values.to( + dtype=self.image_encoder.dtype, device=device + ) + + image_emb = self.image_encoder(image)["image_embeds"] # B, D + image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0) + image_emb.to(device=device) + + return image_emb + + def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + emb = emb.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + init_latents = emb + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], + strength: float = 0.3, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `emb`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. + emb (`torch.FloatTensor`): + The image embedding. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if not isinstance(image, List): + image = [image] + + if isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + if isinstance(image, torch.Tensor) and image.ndim == 2: + # allow user to pass image_embeds directly + image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0) + elif isinstance(image, torch.Tensor) and image.ndim != 4: + raise ValueError( + f" if pass `image` as pytorch tensor, or a list of pytorch tensor, please make sure each tensor has shape [batch_size, channels, height, width], currently {image[0].unsqueeze(0).shape}" + ) + else: + image_embeds = self._encode_image(image, device, num_images_per_prompt) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + + latents = image_embeds + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size) + latents = self.prepare_latents( + latents, + latent_timestep, + batch_size // num_images_per_prompt, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == timesteps.shape[0]: + prev_timestep = None + else: + prev_timestep = timesteps[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/diffusers/pipelines/latent_diffusion/__init__.py b/diffusers/pipelines/latent_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0cce9a89bcbeaac8468d75e9d16c9d3731f738c7 --- /dev/null +++ b/diffusers/pipelines/latent_diffusion/__init__.py @@ -0,0 +1,6 @@ +from ...utils import is_transformers_available +from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline + + +if is_transformers_available(): + from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e25ca759f420cc63716e58b9b12a507f6073b84a Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cf103f1700b6b5f5aabf86a2f04fc7a45c61432 Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ebabe73d24df451cd1d583965213afde489396f Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc differ diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df45fb238fc00cfa5cc10dc9a5e168b1cc461e1e Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc differ diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..257cd82a2c07c2617929ba61c9be2e1469200506 Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc differ diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-38.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a5d2591784cba978dab74cac7a11b867dfef006 Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-38.pyc differ diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ab7c28b96cc85724d925928ba5b065a9ba2bf095 --- /dev/null +++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -0,0 +1,726 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput +from transformers.utils import logging + +from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class LDMTextToImagePipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vqvae: Union[VQModel, AutoencoderKL], + bert: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + unet: Union[UNet2DModel, UNet2DConditionModel], + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 1.0, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at + the, usually at the expense of lower image quality. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get unconditional embeddings for classifier free guidance + if guidance_scale != 1.0: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt" + ) + negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self._execution_device))[0] + + # get prompt text embeddings + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt") + prompt_embeds = self.bert(text_input.input_ids.to(self._execution_device))[0] + + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + latents_shape, generator=generator, device=self._execution_device, dtype=prompt_embeds.dtype + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self._execution_device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale == 1.0: + # guidance_scale of 1 means no guidance + latents_input = latents + context = prompt_embeds + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = torch.cat([latents] * 2) + context = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # predict the noise residual + noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample + # perform guidance + if guidance_scale != 1.0: + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / self.vqvae.config.scaling_factor * latents + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +################################################################################ +# Code for the text transformer model +################################################################################ +""" PyTorch LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ldm-bert", + # See all LDMBert models at https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: LDMBertConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert +class LDMBertPreTrainedModel(PreTrainedModel): + config_class = LDMBertConfig + base_model_prefix = "model" + _supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LDMBertEncoder,)): + module.gradient_checkpointing = value + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class LDMBertEncoder(LDMBertPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`LDMBertEncoderLayer`]. + + Args: + config: LDMBertConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + + self.dropout = config.dropout + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + seq_len = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class LDMBertModel(LDMBertPreTrainedModel): + _no_split_modules = [] + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + self.model = LDMBertEncoder(config) + self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..ae620d325307605fa08fa977b9865dfc9adff057 --- /dev/null +++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -0,0 +1,159 @@ +import inspect +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.utils.checkpoint + +from ...models import UNet2DModel, VQModel +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +def preprocess(image): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class LDMSuperResolutionPipeline(DiffusionPipeline): + r""" + A pipeline for image super-resolution using Latent + + This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], + [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vqvae: VQModel, + unet: UNet2DModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + image: Union[torch.Tensor, PIL.Image.Image] = None, + batch_size: Optional[int] = 1, + num_inference_steps: Optional[int] = 100, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + + height, width = image.shape[-2:] + + # in_channels should be 6: 3 for latents, 3 for low resolution image + latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width) + latents_dtype = next(self.unet.parameters()).dtype + + latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + + image = image.to(device=self.device, dtype=latents_dtype) + + # set timesteps and move to the correct device + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps_tensor = self.scheduler.timesteps + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature. + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(timesteps_tensor): + # concat latents and low resolution image in the channel dimension. + latents_input = torch.cat([latents, image], dim=1) + latents_input = self.scheduler.scale_model_input(latents_input, t) + # predict the noise residual + noise_pred = self.unet(latents_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + # decode the image latents with the VQVAE + image = self.vqvae.decode(latents).sample + image = torch.clamp(image, -1.0, 1.0) + image = image / 2 + 0.5 + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/latent_diffusion_uncond/__init__.py b/diffusers/pipelines/latent_diffusion_uncond/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b9fc5270a62bbb18d1393263101d4b9f73b7511 --- /dev/null +++ b/diffusers/pipelines/latent_diffusion_uncond/__init__.py @@ -0,0 +1 @@ +from .pipeline_latent_diffusion_uncond import LDMPipeline diff --git a/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1217a15beaedd74c3fa3d6918e7e082a15359791 Binary files /dev/null and b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3a2aa4da26c8cb79661cd3a313a5e4fdec44a88 Binary files /dev/null and b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..865be1756ab40c346bfaba993d4253ca0e8dee4a Binary files /dev/null and b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc differ diff --git a/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f53e4259755891e426adcc0de9eabea7ce713c5 Binary files /dev/null and b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc differ diff --git a/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py new file mode 100644 index 0000000000000000000000000000000000000000..73c607a27187eb93a55570a825a4beee329a256c --- /dev/null +++ b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -0,0 +1,111 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel, VQModel +from ...schedulers import DDIMScheduler +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class LDMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents. + """ + + def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + latents = randn_tensor( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + generator=generator, + ) + latents = latents.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + latent_model_input = self.scheduler.scale_model_input(latents, t) + # predict the noise residual + noise_prediction = self.unet(latent_model_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample + + # decode the image latents with the VAE + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/onnx_utils.py b/diffusers/pipelines/onnx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07c32e4e84bfee0241733a077fef9c0dec06905e --- /dev/null +++ b/diffusers/pipelines/onnx_utils.py @@ -0,0 +1,212 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +from pathlib import Path +from typing import Optional, Union + +import numpy as np +from huggingface_hub import hf_hub_download + +from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging + + +if is_onnx_available(): + import onnxruntime as ort + + +logger = logging.get_logger(__name__) + +ORT_TO_NP_TYPE = { + "tensor(bool)": np.bool_, + "tensor(int8)": np.int8, + "tensor(uint8)": np.uint8, + "tensor(int16)": np.int16, + "tensor(uint16)": np.uint16, + "tensor(int32)": np.int32, + "tensor(uint32)": np.uint32, + "tensor(int64)": np.int64, + "tensor(uint64)": np.uint64, + "tensor(float16)": np.float16, + "tensor(float)": np.float32, + "tensor(double)": np.float64, +} + + +class OnnxRuntimeModel: + def __init__(self, model=None, **kwargs): + logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") + self.model = model + self.model_save_dir = kwargs.get("model_save_dir", None) + self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME) + + def __call__(self, **kwargs): + inputs = {k: np.array(v) for k, v in kwargs.items()} + return self.model.run(None, inputs) + + @staticmethod + def load_model(path: Union[str, Path], provider=None, sess_options=None): + """ + Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` + + Arguments: + path (`str` or `Path`): + Directory from which to load + provider(`str`, *optional*): + Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` + """ + if provider is None: + logger.info("No onnxruntime provider specified, using CPUExecutionProvider") + provider = "CPUExecutionProvider" + + return ort.InferenceSession(path, providers=[provider], sess_options=sess_options) + + def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the + latest_model_name. + + Arguments: + save_directory (`str` or `Path`): + Directory where to save the model file. + file_name(`str`, *optional*): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the + model with a different name. + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + + src_path = self.model_save_dir.joinpath(self.latest_model_name) + dst_path = Path(save_directory).joinpath(model_file_name) + try: + shutil.copyfile(src_path, dst_path) + except shutil.SameFileError: + pass + + # copy external weights (for models >2GB) + src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) + if src_path.exists(): + dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) + try: + shutil.copyfile(src_path, dst_path) + except shutil.SameFileError: + pass + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + **kwargs, + ): + """ + Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class + method.: + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # saving model weights/files + self._save_pretrained(save_directory, **kwargs) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + sess_options: Optional["ort.SessionOptions"] = None, + **kwargs, + ): + """ + Load a model from a directory or the HF Hub. + + Arguments: + model_id (`str` or `Path`): + Directory from which to load + use_auth_token (`str` or `bool`): + Is needed to load models from a private or gated repository + revision (`str`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id + cache_dir (`Union[str, Path]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + file_name(`str`): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load + different model files from the same repository or directory. + provider(`str`): + The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. + kwargs (`Dict`, *optional*): + kwargs will be passed to the model during initialization + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + # load model from local directory + if os.path.isdir(model_id): + model = OnnxRuntimeModel.load_model( + os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options + ) + kwargs["model_save_dir"] = Path(model_id) + # load model from hub + else: + # download model + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_name, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + ) + kwargs["model_save_dir"] = Path(model_cache_path).parent + kwargs["latest_model_name"] = Path(model_cache_path).name + model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options) + return cls(model=model, **kwargs) + + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + force_download: bool = True, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, + **model_kwargs, + ): + revision = None + if len(str(model_id).split("@")) == 2: + model_id, revision = model_id.split("@") + + return cls._from_pretrained( + model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + use_auth_token=use_auth_token, + **model_kwargs, + ) diff --git a/diffusers/pipelines/paint_by_example/__init__.py b/diffusers/pipelines/paint_by_example/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0fc8cb71e3f4e1e8baf16c7143658ca64934306 --- /dev/null +++ b/diffusers/pipelines/paint_by_example/__init__.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .image_encoder import PaintByExampleImageEncoder + from .pipeline_paint_by_example import PaintByExamplePipeline diff --git a/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeffcf96e39a190fa9dab2909c06f6320357979c Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f6db4581be9a78ecad6dcfd633de745e765deca Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc b/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5304bfe2805f636cb4fa7bfe7641d9628524ef8 Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc differ diff --git a/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-38.pyc b/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..637ac265a63419ee946d5d364dc2de8ecdce915a Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-38.pyc differ diff --git a/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc b/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c421b21dd0d17647a7ac44f6b5113bd9f9058ae Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc differ diff --git a/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-38.pyc b/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24484935753f0629b253fa553c6e423902a879ae Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-38.pyc differ diff --git a/diffusers/pipelines/paint_by_example/image_encoder.py b/diffusers/pipelines/paint_by_example/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..831489eefed167264c8fd8f57e1ed59610ebb858 --- /dev/null +++ b/diffusers/pipelines/paint_by_example/image_encoder.py @@ -0,0 +1,67 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from transformers import CLIPPreTrainedModel, CLIPVisionModel + +from ...models.attention import BasicTransformerBlock +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PaintByExampleImageEncoder(CLIPPreTrainedModel): + def __init__(self, config, proj_size=768): + super().__init__(config) + self.proj_size = proj_size + + self.model = CLIPVisionModel(config) + self.mapper = PaintByExampleMapper(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + self.proj_out = nn.Linear(config.hidden_size, self.proj_size) + + # uncondition for scaling + self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) + + def forward(self, pixel_values, return_uncond_vector=False): + clip_output = self.model(pixel_values=pixel_values) + latent_states = clip_output.pooler_output + latent_states = self.mapper(latent_states[:, None]) + latent_states = self.final_layer_norm(latent_states) + latent_states = self.proj_out(latent_states) + if return_uncond_vector: + return latent_states, self.uncond_vector + + return latent_states + + +class PaintByExampleMapper(nn.Module): + def __init__(self, config): + super().__init__() + num_layers = (config.num_hidden_layers + 1) // 5 + hid_size = config.hidden_size + num_heads = 1 + self.blocks = nn.ModuleList( + [ + BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states): + for block in self.blocks: + hidden_states = block(hidden_states) + + return hidden_states diff --git a/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py new file mode 100644 index 0000000000000000000000000000000000000000..f844834b527d4cffcffaa589df22c633e358074e --- /dev/null +++ b/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -0,0 +1,557 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .image_encoder import PaintByExampleImageEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Batched mask + if mask.shape[0] == image.shape[0]: + mask = mask.unsqueeze(1) + else: + mask = mask.unsqueeze(0) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + assert mask.shape[1] == 1, "Mask image must have a single channel" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # paint-by-example inverses the mask + mask = 1 - mask + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + if isinstance(image, PIL.Image.Image): + image = [image] + + image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0) + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, PIL.Image.Image): + mask = [mask] + + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + + # paint-by-example inverses the mask + mask = 1 - mask + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * mask + + return mask, masked_image + + +class PaintByExamplePipeline(DiffusionPipeline): + r""" + Pipeline for image-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + image_encoder ([`PaintByExampleImageEncoder`]): + Encodes the example input image. The unet is conditioned on the example image instead of a text prompt. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + # TODO: feature_extractor is required to encode initial images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: PaintByExampleImageEncoder, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + example_image: Union[torch.FloatTensor, PIL.Image.Image], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + example_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + The exemplar image to guide the image generation. + image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. Preprocess mask and image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + height, width = masked_image.shape[-2:] + + # 3. Check inputs + self.check_inputs(example_image, height, width, callback_steps) + + # 4. Encode input image + image_embeddings = self._encode_image( + example_image, device, num_images_per_prompt, do_classifier_free_guidance + ) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + image_embeddings.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/pipeline_flax_utils.py b/diffusers/pipelines/pipeline_flax_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c4b9f53953c1f18e3cc90088dfddd612cbfa63 --- /dev/null +++ b/diffusers/pipelines/pipeline_flax_utils.py @@ -0,0 +1,568 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from typing import Any, Dict, List, Optional, Union + +import flax +import numpy as np +import PIL +from flax.core.frozen_dict import FrozenDict +from huggingface_hub import snapshot_download +from PIL import Image +from tqdm.auto import tqdm + +from ..configuration_utils import ConfigMixin +from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin +from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin +from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging + + +if is_transformers_available(): + from transformers import FlaxPreTrainedModel + +INDEX_FILE = "diffusion_flax_model.bin" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "FlaxModelMixin": ["save_pretrained", "from_pretrained"], + "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"], + "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +def import_flax_or_no_model(module, class_name): + try: + # 1. First make sure that if a Flax object is present, import this one + class_obj = getattr(module, "Flax" + class_name) + except AttributeError: + # 2. If this doesn't work, it's not a model and we don't append "Flax" + class_obj = getattr(module, class_name) + except AttributeError: + raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}") + + return class_obj + + +@flax.struct.dataclass +class FlaxImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class FlaxDiffusionPipeline(ConfigMixin): + r""" + Base class for all models. + + [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion + pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all + pipelines to: + + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + components of the diffusion pipeline. + """ + config_name = "model_index.json" + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + if module is None: + register_dict = {name: (None, None)} + else: + # retrieve library + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrieve class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]): + # TODO: handle inference_state + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class + method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + if sub_model is None: + # edge case for saving a pipeline with safety_checker=None + continue + + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + save_method = getattr(sub_model, save_method_name) + expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) + + if expects_params: + save_method( + os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name] + ) + else: + save_method(os.path.join(save_directory, pipeline_component_name)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a Flax diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + dtype (`str` or `jnp.dtype`, *optional*): + Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + specific pipeline class. The overwritten components are then directly passed to the pipelines + `__init__` method. See example below for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import FlaxDiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> # Requires to be logged in to Hugging Face hub, + >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", + ... revision="bf16", + ... dtype=jnp.bfloat16, + ... ) + + >>> # Download pipeline, but use a different scheduler + >>> from diffusers import FlaxDPMSolverMultistepScheduler + + >>> model_id = "runwayml/stable-diffusion-v1-5" + >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained( + ... model_id, + ... subfolder="scheduler", + ... ) + + >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained( + ... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp + ... ) + >>> dpm_params["scheduler"] = dpmpp_state + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_pt = kwargs.pop("from_pt", False) + use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False) + dtype = kwargs.pop("dtype", None) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.load_config( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + + # make sure we don't download PyTorch weights, unless when using from_pt + ignore_patterns = "*.bin" if not from_pt else [] + + if cls != FlaxDiffusionPipeline: + requested_pipeline_class = cls.__name__ + else: + requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + requested_pipeline_class = ( + requested_pipeline_class + if requested_pipeline_class.startswith("Flax") + else "Flax" + requested_pipeline_class + ) + + user_agent = {"pipeline_class": requested_pipeline_class} + user_agent = http_user_agent(user_agent) + + # download all allow_patterns + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.load_config(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if cls != FlaxDiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + class_name = ( + config_dict["_class_name"] + if config_dict["_class_name"].startswith("Flax") + else "Flax" + config_dict["_class_name"] + ) + pipeline_class = getattr(diffusers_module, class_name) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + init_kwargs = {} + + # inference_params + params = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + if class_name is None: + # edge case for when the pipeline was saved with safety_checker=None + init_kwargs[name] = None + continue + + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + sub_model_should_be_defined = True + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + elif passed_class_obj[name] is None: + logger.warning( + f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" + f" that this might lead to problems when using {pipeline_class} and is not recommended." + ) + sub_model_should_be_defined = False + else: + logger.warning( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = import_flax_or_no_model(pipeline_module, class_name) + + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + class_obj = import_flax_or_no_model(library, class_name) + + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + if loaded_sub_model is None and sub_model_should_be_defined: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + load_method = getattr(class_obj, load_method_name) + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loadable_folder = os.path.join(cached_folder, name) + else: + loaded_sub_model = cached_folder + + if issubclass(class_obj, FlaxModelMixin): + loaded_sub_model, loaded_params = load_method( + loadable_folder, + from_pt=from_pt, + use_memory_efficient_attention=use_memory_efficient_attention, + dtype=dtype, + ) + params[name] = loaded_params + elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): + if from_pt: + # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here + loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) + loaded_params = loaded_sub_model.params + del loaded_sub_model._params + else: + loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) + params[name] = loaded_params + elif issubclass(class_obj, FlaxSchedulerMixin): + loaded_sub_model, scheduler_state = load_method(loadable_folder) + params[name] = scheduler_state + else: + loaded_sub_model = load_method(loadable_folder) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + + if len(missing_modules) > 0 and missing_modules <= set(passed_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) + + model = pipeline_class(**init_kwargs, dtype=dtype) + return model, params + + @staticmethod + def _get_signature_keys(obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + return expected_modules, optional_parameters + + @property + def components(self) -> Dict[str, Any]: + r""" + + The `self.components` property can be useful to run different pipelines with the same weights and + configurations to not have to re-allocate memory. + + Examples: + + ```py + >>> from diffusers import ( + ... FlaxStableDiffusionPipeline, + ... FlaxStableDiffusionImg2ImgPipeline, + ... ) + + >>> text2img = FlaxStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16 + ... ) + >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components) + ``` + + Returns: + A dictionary containing all the modules needed to initialize the pipeline. + """ + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + + if set(components.keys()) != expected_modules: + raise ValueError( + f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" + f" {expected_modules} to be defined, but {components} are defined." + ) + + return components + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + # TODO: make it compatible with jax.lax + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/diffusers/pipelines/pipeline_utils.py b/diffusers/pipelines/pipeline_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad52c6ac1c59ad1b7855f0208c1549123d2623fd --- /dev/null +++ b/diffusers/pipelines/pipeline_utils.py @@ -0,0 +1,1566 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import importlib +import inspect +import os +import re +import sys +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from huggingface_hub import hf_hub_download, model_info, snapshot_download +from packaging import version +from requests.exceptions import HTTPError +from tqdm.auto import tqdm + +import diffusers + +from .. import __version__ +from ..configuration_utils import ConfigMixin +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from ..utils import ( + CONFIG_NAME, + DEPRECATED_REVISION_ARGS, + DIFFUSERS_CACHE, + HF_HUB_OFFLINE, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + BaseOutput, + deprecate, + get_class_from_dynamic_module, + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + is_safetensors_available, + is_torch_version, + is_transformers_available, + logging, + numpy_to_pil, +) + + +if is_transformers_available(): + import transformers + from transformers import PreTrainedModel + from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME + from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME + from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME + +from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME + + +if is_accelerate_available(): + import accelerate + + +INDEX_FILE = "diffusion_pytorch_model.bin" +CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" +DUMMY_MODULES_FOLDER = "diffusers.utils" +TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "ModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_pretrained", "from_pretrained"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], + }, + "onnxruntime.training": { + "ORTModule": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +@dataclass +class ImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class AudioPipelineOutput(BaseOutput): + """ + Output class for audio pipelines. + + Args: + audios (`np.ndarray`) + List of denoised audio samples of a NumPy array of shape `(batch_size, num_channels, sample_rate)`. + """ + + audios: np.ndarray + + +def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool: + """ + Checking for safetensors compatibility: + - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch + files to know which safetensors files are needed. + - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file. + + Converting default pytorch serialized filenames to safetensors serialized filenames: + - For models from the diffusers library, just replace the ".bin" extension with ".safetensors" + - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" + extension is replaced with ".safetensors" + """ + pt_filenames = [] + + sf_filenames = set() + + passed_components = passed_components or [] + + for filename in filenames: + _, extension = os.path.splitext(filename) + + if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components: + continue + + if extension == ".bin": + pt_filenames.append(filename) + elif extension == ".safetensors": + sf_filenames.add(filename) + + for filename in pt_filenames: + # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam' + path, filename = os.path.split(filename) + filename, extension = os.path.splitext(filename) + + if filename.startswith("pytorch_model"): + filename = filename.replace("pytorch_model", "model") + else: + filename = filename + + expected_sf_filename = os.path.join(path, filename) + expected_sf_filename = f"{expected_sf_filename}.safetensors" + + if expected_sf_filename not in sf_filenames: + logger.warning(f"{expected_sf_filename} not found") + return False + + return True + + +def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]: + weight_names = [ + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + FLAX_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ] + + if is_transformers_available(): + weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] + + # model_pytorch, diffusion_model_pytorch, ... + weight_prefixes = [w.split(".")[0] for w in weight_names] + # .bin, .safetensors, ... + weight_suffixs = [w.split(".")[-1] for w in weight_names] + # -00001-of-00002 + transformers_index_format = r"\d{5}-of-\d{5}" + + if variant is not None: + # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors` + variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" + ) + # `text_encoder/pytorch_model.bin.index.fp16.json` + variant_index_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" + ) + + # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` + non_variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" + ) + # `text_encoder/pytorch_model.bin.index.json` + non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") + + if variant is not None: + variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} + variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} + variant_filenames = variant_weights | variant_indexes + else: + variant_filenames = set() + + non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} + non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} + non_variant_filenames = non_variant_weights | non_variant_indexes + + # all variant filenames will be used by default + usable_filenames = set(variant_filenames) + + def convert_to_variant(filename): + if "index" in filename: + variant_filename = filename.replace("index", f"index.{variant}") + elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: + variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" + else: + variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" + return variant_filename + + for f in non_variant_filenames: + variant_filename = convert_to_variant(f) + if variant_filename not in usable_filenames: + usable_filenames.add(f) + + return usable_filenames, variant_filenames + + +def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames): + info = model_info( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + revision=None, + ) + filenames = {sibling.rfilename for sibling in info.siblings} + comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) + comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] + + if set(comp_model_filenames) == set(model_filenames): + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + else: + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.", + FutureWarning, + ) + + +def maybe_raise_or_warn( + library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module +): + """Simple helper method to raise or warn in case incorrect module has been passed""" + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + sub_model = passed_class_obj[name] + model_cls = sub_model.__class__ + if is_compiled_module(sub_model): + model_cls = sub_model._orig_mod.__class__ + + if not issubclass(model_cls, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}" + ) + else: + logger.warning( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + +def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module): + """Simple helper method to retrieve class object of module as well as potential parent class objects""" + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + + class_obj = getattr(pipeline_module, class_name) + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + return class_obj, class_candidates + + +def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None): + if custom_pipeline is not None: + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + else: + file_name = CUSTOM_PIPELINE_FILE_NAME + + return get_class_from_dynamic_module( + custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision + ) + + if class_obj != DiffusionPipeline: + return class_obj + + diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) + return getattr(diffusers_module, config["_class_name"]) + + +def load_sub_model( + library_name: str, + class_name: str, + importable_classes: List[Any], + pipelines: Any, + is_pipeline_module: bool, + pipeline_class: Any, + torch_dtype: torch.dtype, + provider: Any, + sess_options: Any, + device_map: Optional[Union[Dict[str, torch.device], str]], + max_memory: Optional[Dict[Union[int, str], Union[int, str]]], + offload_folder: Optional[Union[str, os.PathLike]], + offload_state_dict: bool, + model_variants: Dict[str, str], + name: str, + from_flax: bool, + variant: str, + low_cpu_mem_usage: bool, + cached_folder: Union[str, os.PathLike], +): + """Helper method to load the module `name` from `library_name` and `class_name`""" + # retrieve class candidates + class_obj, class_candidates = get_class_obj_and_candidates( + library_name, class_name, importable_classes, pipelines, is_pipeline_module + ) + + load_method_name = None + # retrive load method name + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + # if load method name is None, then we have a dummy module -> raise Error + if load_method_name is None: + none_module = class_obj.__module__ + is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( + TRANSFORMERS_DUMMY_MODULES_FOLDER + ) + if is_dummy_path and "dummy" in none_module: + # call class_obj for nice error message of missing requirements + class_obj() + + raise ValueError( + f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" + f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." + ) + + load_method = getattr(class_obj, load_method_name) + + # add kwargs to loading method + loading_kwargs = {} + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options + + is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and transformers_version >= version.parse("4.20.0") + ) + + # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. + # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. + # This makes sure that the weights won't be initialized which significantly speeds up loading. + if is_diffusers_model or is_transformers_model: + loading_kwargs["device_map"] = device_map + loading_kwargs["max_memory"] = max_memory + loading_kwargs["offload_folder"] = offload_folder + loading_kwargs["offload_state_dict"] = offload_state_dict + loading_kwargs["variant"] = model_variants.pop(name, None) + if from_flax: + loading_kwargs["from_flax"] = True + + # the following can be deleted once the minimum required `transformers` version + # is higher than 4.27 + if ( + is_transformers_model + and loading_kwargs["variant"] is not None + and transformers_version < version.parse("4.27.0") + ): + raise ImportError( + f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" + ) + elif is_transformers_model and loading_kwargs["variant"] is None: + loading_kwargs.pop("variant") + + # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` + if not (from_flax and is_transformers_model): + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + else: + loading_kwargs["low_cpu_mem_usage"] = False + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + return loaded_sub_model + + +class DiffusionPipeline(ConfigMixin): + r""" + Base class for all pipelines. + + [`DiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and + provides methods for loading, downloading and saving models. It also includes methods to: + + - move all PyTorch modules to the device of your choice + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + - **_optional_components** (List[`str`]) -- List of all optional components that don't have to be passed to the + pipeline to function (should be overridden by subclasses). + """ + config_name = "model_index.json" + _optional_components = [] + _exclude_from_cpu_offload = [] + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + # retrieve library + if module is None: + register_dict = {name: (None, None)} + else: + # register the config from the original module, not the dynamo compiled one + if is_compiled_module(module): + not_compiled_module = module._orig_mod + else: + not_compiled_module = module + + library = not_compiled_module.__module__.split(".")[0] + + # check if the module is a pipeline module + module_path_items = not_compiled_module.__module__.split(".") + pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None + + path = not_compiled_module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if is_pipeline_module: + library = pipeline_dir + elif library not in LOADABLE_CLASSES: + library = not_compiled_module.__module__ + + # retrieve class_name + class_name = not_compiled_module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def __setattr__(self, name: str, value: Any): + if name in self.__dict__ and hasattr(self.config, name): + # We need to overwrite the config if name exists in config + if isinstance(getattr(self.config, name), (tuple, list)): + if value is not None and self.config[name][0] is not None: + class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__) + else: + class_library_tuple = (None, None) + + self.register_to_config(**{name: class_library_tuple}) + else: + self.register_to_config(**{name: value}) + + super().__setattr__(name, value) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + """ + Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its + class implements both a save and loading method. The pipeline is easily reloaded using the + [`~DiffusionPipeline.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a pipeline to. Will be created if it doesn't exist. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + """ + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name", None) + model_index_dict.pop("_diffusers_version", None) + model_index_dict.pop("_module", None) + model_index_dict.pop("_name_or_path", None) + + expected_modules, optional_kwargs = self._get_signature_keys(self) + + def is_saveable_module(name, value): + if name not in expected_modules: + return False + if name in self._optional_components and value[0] is None: + return False + return True + + model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + if is_compiled_module(sub_model): + sub_model = sub_model._orig_mod + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + if library_name in sys.modules: + library = importlib.import_module(library_name) + else: + logger.info( + f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}" + ) + + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + if save_method_name is None: + logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.") + # make sure that unsaveable components are not tried to be loaded afterward + self.register_to_config(**{pipeline_component_name: (None, None)}) + continue + + save_method = getattr(sub_model, save_method_name) + + # Call the save method with the argument safe_serialization only if it's supported + save_method_signature = inspect.signature(save_method) + save_method_accept_safe = "safe_serialization" in save_method_signature.parameters + save_method_accept_variant = "variant" in save_method_signature.parameters + + save_kwargs = {} + if save_method_accept_safe: + save_kwargs["safe_serialization"] = safe_serialization + if save_method_accept_variant: + save_kwargs["variant"] = variant + + save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) + + # finally save the config + self.save_config(save_directory) + + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + silence_dtype_warnings: bool = False, + ): + if torch_device is None and torch_dtype is None: + return self + + # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. + def module_is_sequentially_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + + return hasattr(module, "_hf_hook") and not isinstance( + module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook) + ) + + def module_is_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): + return False + + return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + + # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer + pipeline_is_sequentially_offloaded = any( + module_is_sequentially_offloaded(module) for _, module in self.components.items() + ) + if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda": + raise ValueError( + "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." + ) + + # Display a warning in this case (the operation succeeds but the benefits are lost) + pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) + if pipeline_is_offloaded and torch.device(torch_device).type == "cuda": + logger.warning( + f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." + ) + + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded + for module in modules: + is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit + + if is_loaded_in_8bit and torch_dtype is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision." + ) + + if is_loaded_in_8bit and torch_device is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}." + ) + else: + module.to(torch_device, torch_dtype) + + if ( + module.dtype == torch.float16 + and str(torch_device) in ["cpu"] + and not silence_dtype_warnings + and not is_offloaded + ): + logger.warning( + "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) + return self + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights. + + The pipeline is set in evaluation mode (`model.eval()`) by default. + + If you get the error message below, you need to finetune the weights for your downstream task: + + ``` + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights + saved using + [`~DiffusionPipeline.save_pretrained`]. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the + dtype is automatically derived from the model's weights. + custom_pipeline (`str`, *optional*): + + + + 🧪 This is an experimental feature and may change in the future. + + + + Can be either: + + - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom + pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines + the custom pipeline. + - A string, the *file name* of a community pipeline hosted on GitHub under + [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file + names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` + instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the + current main branch of GitHub. + - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. The directory + must contain a file called `pipeline.py` that defines the custom pipeline. + + + For more information on how to load and create custom pipelines, please have a look at [Loading and + Adding Custom + Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) + + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn’t need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if device_map contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline + class). The overwritten components are passed directly to the pipelines `__init__` method. See example + below for more information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + + + To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with + `huggingface-cli login`. + + + + Examples: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + + >>> # Use a different scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.scheduler = scheduler + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_flax = kwargs.pop("from_flax", False) + torch_dtype = kwargs.pop("torch_dtype", None) + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + provider = kwargs.pop("provider", None) + sess_options = kwargs.pop("sess_options", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + cached_folder = cls.download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + from_flax=from_flax, + use_safetensors=use_safetensors, + custom_pipeline=custom_pipeline, + custom_revision=custom_revision, + variant=variant, + **kwargs, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.load_config(cached_folder) + + # pop out "_ignore_files" as it is only needed for download + config_dict.pop("_ignore_files", None) + + # 2. Define which model components should load variants + # We retrieve the information by matching whether variant + # model checkpoints exist in the subfolders + model_variants = {} + if variant is not None: + for folder in os.listdir(cached_folder): + folder_path = os.path.join(cached_folder, folder) + is_folder = os.path.isdir(folder_path) and folder in config_dict + variant_exists = is_folder and any( + p.split(".")[1].startswith(variant) for p in os.listdir(folder_path) + ) + if variant_exists: + model_variants[folder] = variant + + # 3. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + pipeline_class = _get_pipeline_class( + cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision + ) + + # DEPRECATED: To be removed in 1.0.0 + if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( + version.parse(config_dict["_diffusers_version"]).base_version + ) <= version.parse("0.5.1"): + from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy + + pipeline_class = StableDiffusionInpaintPipelineLegacy + + deprecation_message = ( + "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" + f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" + " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" + " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" + f" checkpoint {pretrained_model_name_or_path} to the format of" + " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" + " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." + ) + deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + + # 4. Define expected modules given pipeline signature + # and define non-None initialized modules (=`init_kwargs`) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + # define init kwargs + init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + # Special case: safety_checker must be loaded separately when using `from_flax` + if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: + raise NotImplementedError( + "The safety checker cannot be automatically loaded when loading weights `from_flax`." + " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker" + " separately if you need it." + ) + + # 5. Throw nice warnings / errors for fast accelerate loading + if len(unused_kwargs) > 0: + logger.warning( + f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." + ) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # import it here to avoid circular import + from diffusers import pipelines + + # 6. Load each module in the pipeline + for name, (library_name, class_name) in tqdm(init_dict.items(), desc="Loading pipeline components..."): + # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + if class_name.startswith("Flax"): + class_name = class_name[4:] + + # 6.2 Define all importable classes + is_pipeline_module = hasattr(pipelines, library_name) + importable_classes = ALL_IMPORTABLE_CLASSES + loaded_sub_model = None + + # 6.3 Use passed sub model or load class_name from library_name + if name in passed_class_obj: + # if the model is in a pipeline module, then we load it from the pipeline + # check that passed_class_obj has correct parent class + maybe_raise_or_warn( + library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module + ) + + loaded_sub_model = passed_class_obj[name] + else: + # load sub model + loaded_sub_model = load_sub_model( + library_name=library_name, + class_name=class_name, + importable_classes=importable_classes, + pipelines=pipelines, + is_pipeline_module=is_pipeline_module, + pipeline_class=pipeline_class, + torch_dtype=torch_dtype, + provider=provider, + sess_options=sess_options, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + model_variants=model_variants, + name=name, + from_flax=from_flax, + variant=variant, + low_cpu_mem_usage=low_cpu_mem_usage, + cached_folder=cached_folder, + ) + logger.info( + f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." + ) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 7. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) + + # 8. Instantiate the pipeline + model = pipeline_class(**init_kwargs) + + # 9. Save where the model was instantiated from + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + return model + + @property + def name_or_path(self) -> str: + return getattr(self.config, "_name_or_path", None) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. + """ + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: + continue + + if not hasattr(model, "_hf_hook"): + return self.device + for module in model.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + if device == "cuda": + device = torch.device(f"{device}:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if name in self._exclude_from_cpu_offload: + model.to(device) + else: + # make sure to offload buffers if not all high level weights + # are of type nn.Module + offload_buffers = len(model._parameters) > 0 + cpu_offload(model, device, offload_buffers=offload_buffers) + + @classmethod + def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: + r""" + Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights. + + Parameters: + pretrained_model_name (`str` or `os.PathLike`, *optional*): + A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + custom_pipeline (`str`, *optional*): + Can be either: + + - A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained + pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines + the custom pipeline. + + - A string, the *file name* of a community pipeline hosted on GitHub under + [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file + names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` + instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the + current `main` branch of GitHub. + + - A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory + must contain a file called `pipeline.py` that defines the custom pipeline. + + + + 🧪 This is an experimental feature and may change in the future. + + + + For more information on how to load and create custom pipelines, take a look at [How to contribute a + community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline). + + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + custom_revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a + custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + variant (`str`, *optional*): + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + Returns: + `os.PathLike`: + A path to the downloaded pipeline. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. + + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_flax = kwargs.pop("from_flax", False) + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + pipeline_is_cached = False + allow_patterns = None + ignore_patterns = None + + if not local_files_only: + try: + info = model_info( + pretrained_model_name, + use_auth_token=use_auth_token, + revision=revision, + ) + except HTTPError as e: + logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") + local_files_only = True + + if not local_files_only: + config_file = hf_hub_download( + pretrained_model_name, + cls.config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + resume_download=resume_download, + use_auth_token=use_auth_token, + ) + + config_dict = cls._dict_from_json_file(config_file) + + ignore_filenames = config_dict.pop("_ignore_files", []) + + # retrieve all folder_names that contain relevant files + folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] + + filenames = {sibling.rfilename for sibling in info.siblings} + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + + if len(variant_filenames) == 0 and variant is not None: + deprecation_message = ( + f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" + "if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant" + "modeling files is deprecated." + ) + deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False) + + # remove ignored filenames + model_filenames = set(model_filenames) - set(ignore_filenames) + variant_filenames = set(variant_filenames) - set(ignore_filenames) + + # if the whole pipeline is cached we don't have to ping the Hub + if revision in DEPRECATED_REVISION_ARGS and version.parse( + version.parse(__version__).base_version + ) >= version.parse("0.20.0"): + warn_deprecated_model_variant( + pretrained_model_name, use_auth_token, variant, revision, model_filenames + ) + + model_folder_names = {os.path.split(f)[0] for f in model_filenames} + + # all filenames compatible with variant will be added + allow_patterns = list(model_filenames) + + # allow all patterns from non-model folders + # this enables downloading schedulers, tokenizers, ... + allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] + # also allow downloading config.json files with the model + allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] + + allow_patterns += [ + SCHEDULER_CONFIG_NAME, + CONFIG_NAME, + cls.config_name, + CUSTOM_PIPELINE_FILE_NAME, + ] + + # retrieve passed components that should not be downloaded + pipeline_class = _get_pipeline_class( + cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision + ) + expected_components, _ = cls._get_signature_keys(pipeline_class) + passed_components = [k for k in expected_components if k in kwargs] + + if ( + use_safetensors + and not allow_pickle + and not is_safetensors_compatible( + model_filenames, variant=variant, passed_components=passed_components + ) + ): + raise EnvironmentError( + f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})" + ) + if from_flax: + ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] + elif use_safetensors and is_safetensors_compatible( + model_filenames, variant=variant, passed_components=passed_components + ): + ignore_patterns = ["*.bin", "*.msgpack"] + + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} + safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} + if ( + len(safetensors_variant_filenames) > 0 + and safetensors_model_filenames != safetensors_variant_filenames + ): + logger.warn( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." + ) + else: + ignore_patterns = ["*.safetensors", "*.msgpack"] + + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} + bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} + if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: + logger.warn( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." + ) + + # Don't download any objects that are passed + allow_patterns = [ + p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) + ] + # Don't download index files of forbidden patterns either + ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns] + + re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns] + re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns] + + expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)] + expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)] + + snapshot_folder = Path(config_file).parent + pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files) + + if pipeline_is_cached and not force_download: + # if the pipeline is cached, we can directly return it + # else call snapshot_download + return snapshot_folder + + user_agent = {"pipeline_class": cls.__name__} + if custom_pipeline is not None and not custom_pipeline.endswith(".py"): + user_agent["custom_pipeline"] = custom_pipeline + + # download all allow_patterns - ignore_patterns + cached_folder = snapshot_download( + pretrained_model_name, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + + return cached_folder + + @staticmethod + def _get_signature_keys(obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + return expected_modules, optional_parameters + + @property + def components(self) -> Dict[str, Any]: + r""" + The `self.components` property can be useful to run different pipelines with the same weights and + configurations without reallocating additional memory. + + Returns (`dict`): + A dictionary containing all the modules needed to initialize the pipeline. + + Examples: + + ```py + >>> from diffusers import ( + ... StableDiffusionPipeline, + ... StableDiffusionImg2ImgPipeline, + ... StableDiffusionInpaintPipeline, + ... ) + + >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) + >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) + ``` + """ + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + + if set(components.keys()) != expected_modules: + raise ValueError( + f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" + f" {expected_modules} to be defined, but {components.keys()} are defined." + ) + + return components + + @staticmethod + def numpy_to_pil(images): + """ + Convert a NumPy image or a batch of images to a PIL image. + """ + return numpy_to_pil(images) + + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + r""" + Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up during + inference. Speed up during training is not guaranteed. + + + + ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes + precedent. + + + + Parameters: + attention_op (`Callable`, *optional*): + Override the default `None` operator for use as `op` argument to the + [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) + function of xFormers. + + Examples: + + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + >>> # Workaround for not accepting attention shape using VAE for Flash Attention + >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None) + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + """ + self.set_use_memory_efficient_attention_xformers(False) + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + fn_recursive_set_mem_eff(module) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + self.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously called, attention is + computed in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def set_attention_slice(self, slice_size: Optional[int]): + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")] + + for module in modules: + module.set_attention_slice(slice_size) diff --git a/diffusers/pipelines/pndm/__init__.py b/diffusers/pipelines/pndm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..488eb4f5f2b29c071fdc044ef282bc2838148c1e --- /dev/null +++ b/diffusers/pipelines/pndm/__init__.py @@ -0,0 +1 @@ +from .pipeline_pndm import PNDMPipeline diff --git a/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eecc3a7d6435cca7561a6ff05c97142fb0769593 Binary files /dev/null and b/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72c740b00b74c5228132938a010c4ed6a3d30d2d Binary files /dev/null and b/diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc b/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df3cd0ece13264032940cd172d4b7df3728c7552 Binary files /dev/null and b/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc differ diff --git a/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc b/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4625b6034a7f31c6dce792c4d0be8c2aac13a419 Binary files /dev/null and b/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc differ diff --git a/diffusers/pipelines/pndm/pipeline_pndm.py b/diffusers/pipelines/pndm/pipeline_pndm.py new file mode 100644 index 0000000000000000000000000000000000000000..361444079311ad87eb53fc41f02643c4f4bf3c93 --- /dev/null +++ b/diffusers/pipelines/pndm/pipeline_pndm.py @@ -0,0 +1,99 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...schedulers import PNDMScheduler +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class PNDMPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: PNDMScheduler + + def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): + super().__init__() + + scheduler = PNDMScheduler.from_config(scheduler.config) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, `optional`, defaults to 1): The number of images to generate. + num_inference_steps (`int`, `optional`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator`, `optional`): A [torch + generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose + between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a + [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + # For more information on the sampling method you can take a look at Algorithm 2 of + # the official paper: https://arxiv.org/pdf/2202.09778.pdf + + # Sample gaussian noise to begin loop + image = randn_tensor( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + generator=generator, + device=self.device, + ) + + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + model_output = self.unet(image, t).sample + + image = self.scheduler.step(model_output, t, image).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/repaint/__init__.py b/diffusers/pipelines/repaint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16bc86d1cedf6243fb92f7ba331b5a6188133298 --- /dev/null +++ b/diffusers/pipelines/repaint/__init__.py @@ -0,0 +1 @@ +from .pipeline_repaint import RePaintPipeline diff --git a/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bddf43c6760aeb888c6fe703a561ef126862f5a Binary files /dev/null and b/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/repaint/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/repaint/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2157b1f6bde68443aee4bd152f240edcf6a28f90 Binary files /dev/null and b/diffusers/pipelines/repaint/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc b/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..710eb6c5ceed6eff8ab7549474d59ae0f3bb4aff Binary files /dev/null and b/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc differ diff --git a/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-38.pyc b/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9868d7136f5640556b5e1f197bf4e48220420e8c Binary files /dev/null and b/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-38.pyc differ diff --git a/diffusers/pipelines/repaint/pipeline_repaint.py b/diffusers/pipelines/repaint/pipeline_repaint.py new file mode 100644 index 0000000000000000000000000000000000000000..6527a023a74f080f17f0943c7cfd9e446d64d3e2 --- /dev/null +++ b/diffusers/pipelines/repaint/pipeline_repaint.py @@ -0,0 +1,177 @@ +# Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...models import UNet2DModel +from ...schedulers import RePaintScheduler +from ...utils import PIL_INTERPOLATION, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]): + if isinstance(mask, torch.Tensor): + return mask + elif isinstance(mask, PIL.Image.Image): + mask = [mask] + + if isinstance(mask[0], PIL.Image.Image): + w, h = mask[0].size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask] + mask = np.concatenate(mask, axis=0) + mask = mask.astype(np.float32) / 255.0 + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.cat(mask, dim=0) + return mask + + +class RePaintPipeline(DiffusionPipeline): + unet: UNet2DModel + scheduler: RePaintScheduler + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + image: Union[torch.Tensor, PIL.Image.Image], + mask_image: Union[torch.Tensor, PIL.Image.Image], + num_inference_steps: int = 250, + eta: float = 0.0, + jump_length: int = 10, + jump_n_sample: int = 10, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + The original image to inpaint on. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + The mask_image where 0.0 values define which part of the original image to inpaint (change). + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`): + The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is DDIM + and 1.0 is DDPM scheduler respectively. + jump_length (`int`, *optional*, defaults to 10): + The number of steps taken forward in time before going backward in time for a single jump ("j" in + RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. + jump_n_sample (`int`, *optional*, defaults to 10): + The number of times we will make forward time jump for a given chosen time sample. Take a look at + Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + original_image = image + + original_image = _preprocess_image(original_image) + original_image = original_image.to(device=self._execution_device, dtype=self.unet.dtype) + mask_image = _preprocess_mask(mask_image) + mask_image = mask_image.to(device=self._execution_device, dtype=self.unet.dtype) + + batch_size = original_image.shape[0] + + # sample gaussian noise to begin the loop + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image_shape = original_image.shape + image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self._execution_device) + self.scheduler.eta = eta + + t_last = self.scheduler.timesteps[0] + 1 + generator = generator[0] if isinstance(generator, list) else generator + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + if t < t_last: + # predict the noise residual + model_output = self.unet(image, t).sample + # compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample + + else: + # compute the reverse: x_t-1 -> x_t + image = self.scheduler.undo_step(image, t_last, generator) + t_last = t + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/score_sde_ve/__init__.py b/diffusers/pipelines/score_sde_ve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c2a85c067b707c155e78a3c8b84562999134e7 --- /dev/null +++ b/diffusers/pipelines/score_sde_ve/__init__.py @@ -0,0 +1 @@ +from .pipeline_score_sde_ve import ScoreSdeVePipeline diff --git a/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8efda4baeb6ee8efd7408452fc9fcc8c90eb1189 Binary files /dev/null and b/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1785560c9ce94a4305d84023e145bc25b2671450 Binary files /dev/null and b/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc b/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..789757be70f99b234225666e72571e09ce9cefdb Binary files /dev/null and b/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc differ diff --git a/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc b/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3805b5a27912b4fcbc5952bd72d2e1d2449f7910 Binary files /dev/null and b/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc differ diff --git a/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff7b8ee460b58f98c4bd767f70946dc4da2a893 --- /dev/null +++ b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -0,0 +1,101 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...schedulers import ScoreSdeVeScheduler +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class ScoreSdeVePipeline(DiffusionPipeline): + r""" + Parameters: + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): + The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image. + """ + unet: UNet2DModel + scheduler: ScoreSdeVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma + sample = sample.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) + + # correction step + for _ in range(self.scheduler.config.correct_steps): + model_output = self.unet(sample, sigma_t).sample + sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample + + # prediction step + model_output = model(sample, sigma_t).sample + output = self.scheduler.step_pred(model_output, t, sample, generator=generator) + + sample, sample_mean = output.prev_sample, output.prev_sample_mean + + sample = sample_mean.clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + sample = self.numpy_to_pil(sample) + + if not return_dict: + return (sample,) + + return ImagePipelineOutput(images=sample) diff --git a/diffusers/pipelines/semantic_stable_diffusion/__init__.py b/diffusers/pipelines/semantic_stable_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e312c5e30138e106930421ad8c55c23f01e60e7 --- /dev/null +++ b/diffusers/pipelines/semantic_stable_diffusion/__init__.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +class SemanticStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +if is_transformers_available() and is_torch_available(): + from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline diff --git a/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96ca9446d7e783e2ce9eb7bd761e83e795691ca1 Binary files /dev/null and b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e932685eab8d3751a447cf670722a421d2f8347b Binary files /dev/null and b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54061b7e030cd037550c42216559ad66f4b94a68 Binary files /dev/null and b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc differ diff --git a/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-38.pyc b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e164125bb2fb7ff83dd49d33765627959a11ff6b Binary files /dev/null and b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-38.pyc differ diff --git a/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..911a5018de18de505323420f4220551d2b4f8624 --- /dev/null +++ b/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -0,0 +1,724 @@ +import inspect +import warnings +from itertools import repeat +from typing import Callable, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import SemanticStableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SemanticStableDiffusionPipeline + + >>> pipe = SemanticStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> out = pipe( + ... prompt="a photo of the face of a woman", + ... num_images_per_prompt=1, + ... guidance_scale=7, + ... editing_prompt=[ + ... "smiling, smile", # Concepts to apply + ... "glasses, wearing glasses", + ... "curls, wavy hair, curly hair", + ... "beard, full beard, mustache", + ... ], + ... reverse_editing_direction=[ + ... False, + ... False, + ... False, + ... False, + ... ], # Direction of guidance i.e. increase all concepts + ... edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept + ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept + ... edit_threshold=[ + ... 0.99, + ... 0.975, + ... 0.925, + ... 0.96, + ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions + ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance + ... edit_mom_beta=0.6, # Momentum beta + ... edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other + ... ) + >>> image = out.images[0] + ``` +""" + + +class SemanticStableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation with latent editing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + This model builds on the implementation of ['StableDiffusionPipeline'] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`Q16SafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + editing_prompt: Optional[Union[str, List[str]]] = None, + editing_prompt_embeddings: Optional[torch.Tensor] = None, + reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, + edit_guidance_scale: Optional[Union[float, List[float]]] = 5, + edit_warmup_steps: Optional[Union[int, List[int]]] = 10, + edit_cooldown_steps: Optional[Union[int, List[int]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = 0.9, + edit_momentum_scale: Optional[float] = 0.1, + edit_mom_beta: Optional[float] = 0.4, + edit_weights: Optional[List[float]] = None, + sem_guidance: Optional[List[torch.Tensor]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + editing_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting + `editing_prompt = None`. Guidance direction of prompt should be specified via + `reverse_editing_direction`. + editing_prompt_embeddings (`torch.Tensor>`, *optional*): + Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be + specified via `reverse_editing_direction`. + reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): + Whether the corresponding prompt in `editing_prompt` should be increased or decreased. + edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): + Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`. + `edit_guidance_scale` is defined as `s_e` of equation 6 of [SEGA + Paper](https://arxiv.org/pdf/2301.12247.pdf). + edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): + Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum + will still be calculated for those steps and applied once all warmup periods are over. + `edit_warmup_steps` is defined as `delta` (δ) of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). + edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): + Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. + edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): + Threshold of semantic guidance. + edit_momentum_scale (`float`, *optional*, defaults to 0.1): + Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0 + momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `sld_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are + finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of [SEGA + Paper](https://arxiv.org/pdf/2301.12247.pdf). + edit_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous + momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of [SEGA + Paper](https://arxiv.org/pdf/2301.12247.pdf). + edit_weights (`List[float]`, *optional*, defaults to `None`): + Indicates how much each individual concept should influence the overall guidance. If no weights are + provided all concepts are applied equally. `edit_mom_beta` is defined as `g_i` of equation 9 of [SEGA + Paper](https://arxiv.org/pdf/2301.12247.pdf). + sem_guidance (`List[torch.Tensor]`, *optional*): + List of pre-generated guidance vectors to be applied at generation. Length of the list has to + correspond to `num_inference_steps`. + + Returns: + [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] if `return_dict` is True, + otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the + second element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeddings is not None: + enable_edit_guidance = True + enabled_editing_prompts = editing_prompt_embeddings.shape[0] + else: + enabled_editing_prompts = 0 + enable_edit_guidance = False + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if enable_edit_guidance: + # get safety text embeddings + if editing_prompt_embeddings is None: + edit_concepts_input = self.tokenizer( + [x for item in editing_prompt for x in repeat(item, batch_size)], + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + + edit_concepts_input_ids = edit_concepts_input.input_ids + + if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode( + edit_concepts_input_ids[:, self.tokenizer.model_max_length :] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] + edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] + else: + edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed_edit, seq_len_edit, _ = edit_concepts.shape + edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) + edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if enable_edit_guidance: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + # get the initial random noise unless the user supplied it + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + self.device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Initialize edit_momentum to None + edit_momentum = None + + self.uncond_estimates = None + self.text_estimates = None + self.edit_estimates = None + self.sem_guidance = None + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + noise_pred_edit_concepts = noise_pred_out[2:] + + # default text guidance + noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) + # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) + + if self.uncond_estimates is None: + self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape)) + self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() + + if self.text_estimates is None: + self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) + self.text_estimates[i] = noise_pred_text.detach().cpu() + + if self.edit_estimates is None and enable_edit_guidance: + self.edit_estimates = torch.zeros( + (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) + ) + + if self.sem_guidance is None: + self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) + + if edit_momentum is None: + edit_momentum = torch.zeros_like(noise_guidance) + + if enable_edit_guidance: + concept_weights = torch.zeros( + (len(noise_pred_edit_concepts), noise_guidance.shape[0]), + device=self.device, + dtype=noise_guidance.dtype, + ) + noise_guidance_edit = torch.zeros( + (len(noise_pred_edit_concepts), *noise_guidance.shape), + device=self.device, + dtype=noise_guidance.dtype, + ) + # noise_guidance_edit = torch.zeros_like(noise_guidance) + warmup_inds = [] + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + self.edit_estimates[i, c] = noise_pred_edit_concept + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + if edit_weights: + edit_weight_c = edit_weights[c] + else: + edit_weight_c = 1.0 + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + if i >= edit_warmup_steps_c: + warmup_inds.append(c) + if i >= edit_cooldown_steps_c: + noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) + continue + + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) + tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) + + tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + concept_weights[c, :] = tmp_weights + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp.dtype == torch.float32: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp.dtype) + + noise_guidance_edit_tmp = torch.where( + torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp + + # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp + + warmup_inds = torch.tensor(warmup_inds).to(self.device) + if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: + concept_weights = concept_weights.to("cpu") # Offload to cpu + noise_guidance_edit = noise_guidance_edit.to("cpu") + + concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) + concept_weights_tmp = torch.where( + concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp + ) + concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) + # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) + + noise_guidance_edit_tmp = torch.index_select( + noise_guidance_edit.to(self.device), 0, warmup_inds + ) + noise_guidance_edit_tmp = torch.einsum( + "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp + ) + noise_guidance_edit_tmp = noise_guidance_edit_tmp + noise_guidance = noise_guidance + noise_guidance_edit_tmp + + self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() + + del noise_guidance_edit_tmp + del concept_weights_tmp + concept_weights = concept_weights.to(self.device) + noise_guidance_edit = noise_guidance_edit.to(self.device) + + concept_weights = torch.where( + concept_weights < 0, torch.zeros_like(concept_weights), concept_weights + ) + + concept_weights = torch.nan_to_num(concept_weights) + + noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) + + noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum + + edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit + + if warmup_inds.shape[0] == len(noise_pred_edit_concepts): + noise_guidance = noise_guidance + noise_guidance_edit + self.sem_guidance[i] = noise_guidance_edit.detach().cpu() + + if sem_guidance is not None: + edit_guidance = sem_guidance[i].to(self.device) + noise_guidance = noise_guidance + edit_guidance + + noise_pred = noise_pred_uncond + noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/shap_e/__init__.py b/diffusers/pipelines/shap_e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04aa1f2f6d7852877e4c7f8b07cd15a8d1d496f5 --- /dev/null +++ b/diffusers/pipelines/shap_e/__init__.py @@ -0,0 +1,27 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline +else: + from .camera import create_pan_cameras + from .pipeline_shap_e import ShapEPipeline + from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline + from .renderer import ( + BoundingBoxVolume, + ImportanceRaySampler, + MLPNeRFModelOutput, + MLPNeRSTFModel, + ShapEParamsProjModel, + ShapERenderer, + StratifiedRaySampler, + VoidNeRFModel, + ) diff --git a/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a79af1c2b64a6bfd6b8a00f69c2ea209434552eb Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71bb328f3133c9297127006c84e0190dc490c853 Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/camera.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/camera.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3271468d32ceff0cfdd4baf905e42bd1eb50da45 Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/camera.cpython-310.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/camera.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/camera.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54a30e960a8dde49f46b2194739af5ad2fa38e57 Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/camera.cpython-38.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12f4d8adf926e89e7790347a075c258965eff465 Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-310.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..930a2d9566c7b4ca233cdb3cef903baaba357e50 Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-38.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d7ae7e7f1d0e3d91249a0a7dc4c88893db54f1 Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb320ee0b0b8de9dd76886e0451513550d65256d Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2540fc9bf3b0aff5f07a52fe9c768cc903d917f Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-310.pyc differ diff --git a/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca37932bafab2a20951d4df7dc3a4970d046b2d3 Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-38.pyc differ diff --git a/diffusers/pipelines/shap_e/camera.py b/diffusers/pipelines/shap_e/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef0d66070223a80eed59da8d842389fed0c7aef --- /dev/null +++ b/diffusers/pipelines/shap_e/camera.py @@ -0,0 +1,147 @@ +# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import torch + + +@dataclass +class DifferentiableProjectiveCamera: + """ + Implements a batch, differentiable, standard pinhole camera + """ + + origin: torch.Tensor # [batch_size x 3] + x: torch.Tensor # [batch_size x 3] + y: torch.Tensor # [batch_size x 3] + z: torch.Tensor # [batch_size x 3] + width: int + height: int + x_fov: float + y_fov: float + shape: Tuple[int] + + def __post_init__(self): + assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0] + assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3 + assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2 + + def resolution(self): + return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32)) + + def fov(self): + return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32)) + + def get_image_coords(self) -> torch.Tensor: + """ + :return: coords of shape (width * height, 2) + """ + pixel_indices = torch.arange(self.height * self.width) + coords = torch.stack( + [ + pixel_indices % self.width, + torch.div(pixel_indices, self.width, rounding_mode="trunc"), + ], + axis=1, + ) + return coords + + @property + def camera_rays(self): + batch_size, *inner_shape = self.shape + inner_batch_size = int(np.prod(inner_shape)) + + coords = self.get_image_coords() + coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape]) + rays = self.get_camera_rays(coords) + + rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3) + + return rays + + def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor: + batch_size, *shape, n_coords = coords.shape + assert n_coords == 2 + assert batch_size == self.origin.shape[0] + + flat = coords.view(batch_size, -1, 2) + + res = self.resolution() + fov = self.fov() + + fracs = (flat.float() / (res - 1)) * 2 - 1 + fracs = fracs * torch.tan(fov / 2) + + fracs = fracs.view(batch_size, -1, 2) + directions = ( + self.z.view(batch_size, 1, 3) + + self.x.view(batch_size, 1, 3) * fracs[:, :, :1] + + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:] + ) + directions = directions / directions.norm(dim=-1, keepdim=True) + rays = torch.stack( + [ + torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]), + directions, + ], + dim=2, + ) + return rays.view(batch_size, *shape, 2, 3) + + def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera": + """ + Creates a new camera for the resized view assuming the aspect ratio does not change. + """ + assert width * self.height == height * self.width, "The aspect ratio should not change." + return DifferentiableProjectiveCamera( + origin=self.origin, + x=self.x, + y=self.y, + z=self.z, + width=width, + height=height, + x_fov=self.x_fov, + y_fov=self.y_fov, + ) + + +def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera: + origins = [] + xs = [] + ys = [] + zs = [] + for theta in np.linspace(0, 2 * np.pi, num=20): + z = np.array([np.sin(theta), np.cos(theta), -0.5]) + z /= np.sqrt(np.sum(z**2)) + origin = -z * 4 + x = np.array([np.cos(theta), -np.sin(theta), 0.0]) + y = np.cross(z, x) + origins.append(origin) + xs.append(x) + ys.append(y) + zs.append(z) + return DifferentiableProjectiveCamera( + origin=torch.from_numpy(np.stack(origins, axis=0)).float(), + x=torch.from_numpy(np.stack(xs, axis=0)).float(), + y=torch.from_numpy(np.stack(ys, axis=0)).float(), + z=torch.from_numpy(np.stack(zs, axis=0)).float(), + width=size, + height=size, + x_fov=0.7, + y_fov=0.7, + shape=(1, len(xs)), + ) diff --git a/diffusers/pipelines/shap_e/pipeline_shap_e.py b/diffusers/pipelines/shap_e/pipeline_shap_e.py new file mode 100644 index 0000000000000000000000000000000000000000..fdcbe550860a403d2869321affd9045249472dad --- /dev/null +++ b/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -0,0 +1,354 @@ +# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from .renderer import ShapERenderer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif + + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + >>> repo = "openai/shap-e" + >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> guidance_scale = 15.0 + >>> prompt = "a shark" + + >>> images = pipe( + ... prompt, + ... guidance_scale=guidance_scale, + ... num_inference_steps=64, + ... frame_size=256, + ... ).images + + >>> gif_path = export_to_gif(images[0], "shark_3d.gif") + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for ShapEPipeline. + + Args: + images (`torch.FloatTensor`) + a list of images for 3D rendering + """ + + images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]] + + +class ShapEPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + renderer ([`ShapERenderer`]): + Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects + with the NeRF rendering method + """ + + def __init__( + self, + prior: PriorTransformer, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: HeunDiscreteScheduler, + renderer: ShapERenderer, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + renderer=renderer, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + len(prompt) if isinstance(prompt, list) else 1 + + # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file + self.tokenizer.pad_token_id = 0 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + prompt_embeds = text_encoder_output.text_embeds + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + # in Shap-E it normalize the prompt_embeds and then later rescale it + prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # Rescale the features to have unit variance + prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds + + return prompt_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + frame_size: int = 64, + output_type: Optional[str] = "pil", # pil, np, latent + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + frame_size (`int`, *optional*, default to 64): + the width and height of each image frame of the generated 3d output + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`ShapEPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + + # prior + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=prompt_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + if do_classifier_free_guidance is not None: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + ).prev_sample + + if output_type == "latent": + return ShapEPipelineOutput(images=latents) + + images = [] + for i, latent in enumerate(latents): + image = self.renderer.decode( + latent[None, :], + device, + size=frame_size, + ray_batch_size=4096, + n_coarse_samples=64, + n_fine_samples=128, + ) + images.append(image) + + images = torch.stack(images) + + if output_type not in ["np", "pil"]: + raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}") + + images = images.cpu().numpy() + + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (images,) + + return ShapEPipelineOutput(images=images) diff --git a/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..08c585c5ad736d6f12397409187dc77659b771e3 --- /dev/null +++ b/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -0,0 +1,312 @@ +# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPVisionModel + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + logging, + randn_tensor, + replace_example_docstring, +) +from .renderer import ShapERenderer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + >>> repo = "openai/shap-e-img2img" + >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> guidance_scale = 3.0 + >>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png" + >>> image = load_image(image_url).convert("RGB") + + >>> images = pipe( + ... image, + ... guidance_scale=guidance_scale, + ... num_inference_steps=64, + ... frame_size=256, + ... ).images + + >>> gif_path = export_to_gif(images[0], "corgi_3d.gif") + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for ShapEPipeline. + + Args: + images (`torch.FloatTensor`) + a list of images for 3D rendering + """ + + images: Union[PIL.Image.Image, np.ndarray] + + +class ShapEImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + renderer ([`ShapERenderer`]): + Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects + with the NeRF rendering method + """ + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + scheduler: HeunDiscreteScheduler, + renderer: ShapERenderer, + ): + super().__init__() + + self.register_modules( + prior=prior, + image_encoder=image_encoder, + image_processor=image_processor, + scheduler=scheduler, + renderer=renderer, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_image( + self, + image, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + if isinstance(image, List) and isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if not isinstance(image, torch.Tensor): + image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0) + + image = image.to(dtype=self.image_encoder.dtype, device=device) + + image_embeds = self.image_encoder(image)["last_hidden_state"] + image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, dim, 256 + + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image]], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + frame_size: int = 64, + output_type: Optional[str] = "pil", # pil, np, latent + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + frame_size (`int`, *optional*, default to 64): + the width and height of each image frame of the generated 3d output + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`ShapEPipelineOutput`] or `tuple` + """ + + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)): + batch_size = len(image) + else: + raise ValueError( + f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}" + ) + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # prior + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=image_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + if do_classifier_free_guidance is not None: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + ).prev_sample + + if output_type == "latent": + return ShapEPipelineOutput(images=latents) + + images = [] + for i, latent in enumerate(latents): + print() + image = self.renderer.decode( + latent[None, :], + device, + size=frame_size, + ray_batch_size=4096, + n_coarse_samples=64, + n_fine_samples=128, + ) + + images.append(image) + + images = torch.stack(images) + + if output_type not in ["np", "pil"]: + raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}") + + images = images.cpu().numpy() + + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (images,) + + return ShapEPipelineOutput(images=images) diff --git a/diffusers/pipelines/shap_e/renderer.py b/diffusers/pipelines/shap_e/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..8b075e671f63d9f6cbddcfb205df1ba38a426e6f --- /dev/null +++ b/diffusers/pipelines/shap_e/renderer.py @@ -0,0 +1,709 @@ +# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...utils import BaseOutput +from .camera import create_pan_cameras + + +def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor: + r""" + Sample from the given discrete probability distribution with replacement. + + The i-th bin is assumed to have mass pmf[i]. + + Args: + pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() + n_samples: number of samples + + Return: + indices sampled with replacement + """ + + *shape, support_size, last_dim = pmf.shape + assert last_dim == 1 + + cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) + inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) + + return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) + + +def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor: + """ + Concatenate x and its positional encodings, following NeRF. + + Reference: https://arxiv.org/pdf/2210.04628.pdf + """ + if min_deg == max_deg: + return x + + scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device) + *shape, dim = x.shape + xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1) + assert xb.shape[-1] == dim * (max_deg - min_deg) + emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin() + return torch.cat([x, emb], dim=-1) + + +def encode_position(position): + return posenc_nerf(position, min_deg=0, max_deg=15) + + +def encode_direction(position, direction=None): + if direction is None: + return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8)) + else: + return posenc_nerf(direction, min_deg=0, max_deg=8) + + +def _sanitize_name(x: str) -> str: + return x.replace(".", "__") + + +def integrate_samples(volume_range, ts, density, channels): + r""" + Function integrating the model output. + + Args: + volume_range: Specifies the integral range [t0, t1] + ts: timesteps + density: torch.Tensor [batch_size, *shape, n_samples, 1] + channels: torch.Tensor [batch_size, *shape, n_samples, n_channels] + returns: + channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density + *transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume + ) + """ + + # 1. Calculate the weights + _, _, dt = volume_range.partition(ts) + ddensity = density * dt + + mass = torch.cumsum(ddensity, dim=-2) + transmittance = torch.exp(-mass[..., -1, :]) + + alphas = 1.0 - torch.exp(-ddensity) + Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2)) + # This is the probability of light hitting and reflecting off of + # something at depth [..., i, :]. + weights = alphas * Ts + + # 2. Integrate channels + channels = torch.sum(channels * weights, dim=-2) + + return channels, weights, transmittance + + +class VoidNeRFModel(nn.Module): + """ + Implements the default empty space model where all queries are rendered as background. + """ + + def __init__(self, background, channel_scale=255.0): + super().__init__() + background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale) + + self.register_buffer("background", background) + + def forward(self, position): + background = self.background[None].to(position.device) + + shape = position.shape[:-1] + ones = [1] * (len(shape) - 1) + n_channels = background.shape[-1] + background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]) + + return background + + +@dataclass +class VolumeRange: + t0: torch.Tensor + t1: torch.Tensor + intersected: torch.Tensor + + def __post_init__(self): + assert self.t0.shape == self.t1.shape == self.intersected.shape + + def partition(self, ts): + """ + Partitions t0 and t1 into n_samples intervals. + + Args: + ts: [batch_size, *shape, n_samples, 1] + + Return: + + lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size, + *shape, n_samples, 1] + + where + ts \\in [lower, upper] deltas = upper - lower + """ + + mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5 + lower = torch.cat([self.t0[..., None, :], mids], dim=-2) + upper = torch.cat([mids, self.t1[..., None, :]], dim=-2) + delta = upper - lower + assert lower.shape == upper.shape == delta.shape == ts.shape + return lower, upper, delta + + +class BoundingBoxVolume(nn.Module): + """ + Axis-aligned bounding box defined by the two opposite corners. + """ + + def __init__( + self, + *, + bbox_min, + bbox_max, + min_dist: float = 0.0, + min_t_range: float = 1e-3, + ): + """ + Args: + bbox_min: the left/bottommost corner of the bounding box + bbox_max: the other corner of the bounding box + min_dist: all rays should start at least this distance away from the origin. + """ + super().__init__() + + self.min_dist = min_dist + self.min_t_range = min_t_range + + self.bbox_min = torch.tensor(bbox_min) + self.bbox_max = torch.tensor(bbox_max) + self.bbox = torch.stack([self.bbox_min, self.bbox_max]) + assert self.bbox.shape == (2, 3) + assert min_dist >= 0.0 + assert min_t_range > 0.0 + + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: Optional[torch.Tensor] = None, + epsilon=1e-6, + ): + """ + Args: + origin: [batch_size, *shape, 3] + direction: [batch_size, *shape, 3] + t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. + params: Optional meta parameters in case Volume is parametric + epsilon: to stabilize calculations + + Return: + A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with + the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to + be on the boundary of the volume. + """ + + batch_size, *shape, _ = origin.shape + ones = [1] * len(shape) + bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device) + + def _safe_divide(a, b, epsilon=1e-6): + return a / torch.where(b < 0, b - epsilon, b + epsilon) + + ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) + + # Cases to think about: + # + # 1. t1 <= t0: the ray does not pass through the AABB. + # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. + # 3. t0 <= 0 <= t1: the ray starts from inside the BB + # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. + # + # 1 and 4 are clearly handled from t0 < t1 below. + # Making t0 at least min_dist (>= 0) takes care of 2 and 3. + t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist) + t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values + assert t0.shape == t1.shape == (batch_size, *shape, 1) + if t0_lower is not None: + assert t0.shape == t0_lower.shape + t0 = torch.maximum(t0, t0_lower) + + intersected = t0 + self.min_t_range < t1 + t0 = torch.where(intersected, t0, torch.zeros_like(t0)) + t1 = torch.where(intersected, t1, torch.ones_like(t1)) + + return VolumeRange(t0=t0, t1=t1, intersected=intersected) + + +class StratifiedRaySampler(nn.Module): + """ + Instead of fixed intervals, a sample is drawn uniformly at random from each interval. + """ + + def __init__(self, depth_mode: str = "linear"): + """ + :param depth_mode: linear samples ts linearly in depth. harmonic ensures + closer points are sampled more densely. + """ + self.depth_mode = depth_mode + assert self.depth_mode in ("linear", "geometric", "harmonic") + + def sample( + self, + t0: torch.Tensor, + t1: torch.Tensor, + n_samples: int, + epsilon: float = 1e-3, + ) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + ones = [1] * (len(t0.shape) - 1) + ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device) + + if self.depth_mode == "linear": + ts = t0 * (1.0 - ts) + t1 * ts + elif self.depth_mode == "geometric": + ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp() + elif self.depth_mode == "harmonic": + # The original NeRF recommends this interpolation scheme for + # spherical scenes, but there could be some weird edge cases when + # the observer crosses from the inner to outer volume. + ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts) + + mids = 0.5 * (ts[..., 1:] + ts[..., :-1]) + upper = torch.cat([mids, t1], dim=-1) + lower = torch.cat([t0, mids], dim=-1) + # yiyi notes: add a random seed here for testing, don't forget to remove + torch.manual_seed(0) + t_rand = torch.rand_like(ts) + + ts = lower + (upper - lower) * t_rand + return ts.unsqueeze(-1) + + +class ImportanceRaySampler(nn.Module): + """ + Given the initial estimate of densities, this samples more from regions/bins expected to have objects. + """ + + def __init__( + self, + volume_range: VolumeRange, + ts: torch.Tensor, + weights: torch.Tensor, + blur_pool: bool = False, + alpha: float = 1e-5, + ): + """ + Args: + volume_range: the range in which a ray intersects the given volume. + ts: earlier samples from the coarse rendering step + weights: discretized version of density * transmittance + blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF. + alpha: small value to add to weights. + """ + self.volume_range = volume_range + self.ts = ts.clone().detach() + self.weights = weights.clone().detach() + self.blur_pool = blur_pool + self.alpha = alpha + + @torch.no_grad() + def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + lower, upper, _ = self.volume_range.partition(self.ts) + + batch_size, *shape, n_coarse_samples, _ = self.ts.shape + + weights = self.weights + if self.blur_pool: + padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2) + maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :]) + weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :]) + weights = weights + self.alpha + pmf = weights / weights.sum(dim=-2, keepdim=True) + inds = sample_pmf(pmf, n_samples) + assert inds.shape == (batch_size, *shape, n_samples, 1) + assert (inds >= 0).all() and (inds < n_coarse_samples).all() + + t_rand = torch.rand(inds.shape, device=inds.device) + lower_ = torch.gather(lower, -2, inds) + upper_ = torch.gather(upper, -2, inds) + + ts = lower_ + (upper_ - lower_) * t_rand + ts = torch.sort(ts, dim=-2).values + return ts + + +@dataclass +class MLPNeRFModelOutput(BaseOutput): + density: torch.Tensor + signed_distance: torch.Tensor + channels: torch.Tensor + ts: torch.Tensor + + +class MLPNeRSTFModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + ): + super().__init__() + + # Instantiate the MLP + + # Find out the dimension of encoded position and direction + dummy = torch.eye(1, 3) + d_posenc_pos = encode_position(position=dummy).shape[-1] + d_posenc_dir = encode_direction(position=dummy).shape[-1] + + mlp_widths = [d_hidden] * n_hidden_layers + input_widths = [d_posenc_pos] + mlp_widths + output_widths = mlp_widths + [n_output] + + if insert_direction_at is not None: + input_widths[insert_direction_at] += d_posenc_dir + + self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)]) + + if act_fn == "swish": + # self.activation = swish + # yiyi testing: + self.activation = lambda x: F.silu(x) + else: + raise ValueError(f"Unsupported activation function {act_fn}") + + self.sdf_activation = torch.tanh + self.density_activation = torch.nn.functional.relu + self.channel_activation = torch.sigmoid + + def map_indices_to_keys(self, output): + h_map = { + "sdf": (0, 1), + "density_coarse": (1, 2), + "density_fine": (2, 3), + "stf": (3, 6), + "nerf_coarse": (6, 9), + "nerf_fine": (9, 12), + } + + mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()} + + return mapped_output + + def forward(self, *, position, direction, ts, nerf_level="coarse"): + h = encode_position(position) + + h_preact = h + h_directionless = None + for i, layer in enumerate(self.mlp): + if i == self.config.insert_direction_at: # 4 in the config + h_directionless = h_preact + h_direction = encode_direction(position, direction=direction) + h = torch.cat([h, h_direction], dim=-1) + + h = layer(h) + + h_preact = h + + if i < len(self.mlp) - 1: + h = self.activation(h) + + h_final = h + if h_directionless is None: + h_directionless = h_preact + + activation = self.map_indices_to_keys(h_final) + + if nerf_level == "coarse": + h_density = activation["density_coarse"] + h_channels = activation["nerf_coarse"] + else: + h_density = activation["density_fine"] + h_channels = activation["nerf_fine"] + + density = self.density_activation(h_density) + signed_distance = self.sdf_activation(activation["sdf"]) + channels = self.channel_activation(h_channels) + + # yiyi notes: I think signed_distance is not used + return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts) + + +class ChannelsProj(nn.Module): + def __init__( + self, + *, + vectors: int, + channels: int, + d_latent: int, + ): + super().__init__() + self.proj = nn.Linear(d_latent, vectors * channels) + self.norm = nn.LayerNorm(channels) + self.d_latent = d_latent + self.vectors = vectors + self.channels = channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_bvd = x + w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent) + b_vc = self.proj.bias.view(1, self.vectors, self.channels) + h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd) + h = self.norm(h) + + h = h + b_vc + return h + + +class ShapEParamsProjModel(ModelMixin, ConfigMixin): + """ + project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP). + + For more details, see the original paper: + """ + + @register_to_config + def __init__( + self, + *, + param_names: Tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: Tuple[Tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + ): + super().__init__() + + # check inputs + if len(param_names) != len(param_shapes): + raise ValueError("Must provide same number of `param_names` as `param_shapes`") + self.projections = nn.ModuleDict({}) + for k, (vectors, channels) in zip(param_names, param_shapes): + self.projections[_sanitize_name(k)] = ChannelsProj( + vectors=vectors, + channels=channels, + d_latent=d_latent, + ) + + def forward(self, x: torch.Tensor): + out = {} + start = 0 + for k, shape in zip(self.config.param_names, self.config.param_shapes): + vectors, _ = shape + end = start + vectors + x_bvd = x[:, start:end] + out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape) + start = end + return out + + +class ShapERenderer(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + *, + param_names: Tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: Tuple[Tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + background: Tuple[float] = ( + 255.0, + 255.0, + 255.0, + ), + ): + super().__init__() + + self.params_proj = ShapEParamsProjModel( + param_names=param_names, + param_shapes=param_shapes, + d_latent=d_latent, + ) + self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at) + self.void = VoidNeRFModel(background=background, channel_scale=255.0) + self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0]) + + @torch.no_grad() + def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False): + """ + Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below + with some abuse of notations) + + C(r) := sum( + transmittance(t[i]) * integrate( + lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]], + ) for i in range(len(parts)) + ) + transmittance(t[-1]) * void_model(t[-1]).channels + + where + + 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through + the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are + obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t + where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the + shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and + transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1], + math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + + args: + rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples: + number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including + + :return: A tuple of + - `channels` + - A importance samplers for additional fine-grained rendering + - raw model output + """ + origin, direction = rays[..., 0, :], rays[..., 1, :] + + # Integrate over [t[i], t[i + 1]] + + # 1 Intersect the rays with the current volume and sample ts to integrate along. + vrange = self.volume.intersect(origin, direction, t0_lower=None) + ts = sampler.sample(vrange.t0, vrange.t1, n_samples) + ts = ts.to(rays.dtype) + + if prev_model_out is not None: + # Append the previous ts now before fprop because previous + # rendering used a different model and we can't reuse the output. + ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values + + batch_size, *_shape, _t0_dim = vrange.t0.shape + _, *ts_shape, _ts_dim = ts.shape + + # 2. Get the points along the ray and query the model + directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3]) + positions = origin.unsqueeze(-2) + ts * directions + + directions = directions.to(self.mlp.dtype) + positions = positions.to(self.mlp.dtype) + + optional_directions = directions if render_with_direction else None + + model_out = self.mlp( + position=positions, + direction=optional_directions, + ts=ts, + nerf_level="coarse" if prev_model_out is None else "fine", + ) + + # 3. Integrate the model results + channels, weights, transmittance = integrate_samples( + vrange, model_out.ts, model_out.density, model_out.channels + ) + + # 4. Clean up results that do not intersect with the volume. + transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance)) + channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels)) + # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + channels = channels + transmittance * self.void(origin) + + weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights) + + return channels, weighted_sampler, model_out + + @torch.no_grad() + def decode( + self, + latents, + device, + size: int = 64, + ray_batch_size: int = 4096, + n_coarse_samples=64, + n_fine_samples=128, + ): + # project the the paramters from the generated latents + projected_params = self.params_proj(latents) + + # update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # create cameras object + camera = create_pan_cameras(size) + rays = camera.camera_rays + rays = rays.to(device) + n_batches = rays.shape[1] // ray_batch_size + + coarse_sampler = StratifiedRaySampler() + + images = [] + + for idx in range(n_batches): + rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size] + + # render rays with coarse, stratified samples. + _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples) + # Then, render with additional importance-weighted ray samples. + channels, _, _ = self.render_rays( + rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out + ) + + images.append(channels) + + images = torch.cat(images, dim=1) + images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0) + + return images diff --git a/diffusers/pipelines/spectrogram_diffusion/__init__.py b/diffusers/pipelines/spectrogram_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05b14a857630e7a7c001a8ae4c23772dfc62a08a --- /dev/null +++ b/diffusers/pipelines/spectrogram_diffusion/__init__.py @@ -0,0 +1,26 @@ +# flake8: noqa +from ...utils import is_note_seq_available, is_transformers_available, is_torch_available +from ...utils import OptionalDependencyNotAvailable + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .notes_encoder import SpectrogramNotesEncoder + from .continous_encoder import SpectrogramContEncoder + from .pipeline_spectrogram_diffusion import ( + SpectrogramContEncoder, + SpectrogramDiffusionPipeline, + T5FilmDecoder, + ) + +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 +else: + from .midi_utils import MidiProcessor diff --git a/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py b/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..556136d4023df32e4df2477523463829a0722db4 --- /dev/null +++ b/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py @@ -0,0 +1,92 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import ( + T5Block, + T5Config, + T5LayerNorm, +) + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + input_dims: int, + targets_context_length: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.input_proj = nn.Linear(input_dims, d_model, bias=False) + + self.position_encoding = nn.Embedding(targets_context_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + feed_forward_proj=feed_forward_proj, + dropout_rate=dropout_rate, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_inputs, encoder_inputs_mask): + x = self.input_proj(encoder_inputs) + + # terminal relative positional encodings + max_positions = encoder_inputs.shape[1] + input_positions = torch.arange(max_positions, device=encoder_inputs.device) + + seq_lens = encoder_inputs_mask.sum(-1) + input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0) + x += self.position_encoding(input_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_inputs.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/diffusers/pipelines/spectrogram_diffusion/midi_utils.py b/diffusers/pipelines/spectrogram_diffusion/midi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..08d0878db588aa38a2e602a3bc5f6505b9457575 --- /dev/null +++ b/diffusers/pipelines/spectrogram_diffusion/midi_utils.py @@ -0,0 +1,667 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import math +import os +from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ...utils import is_note_seq_available +from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH + + +if is_note_seq_available(): + import note_seq +else: + raise ImportError("Please install note-seq via `pip install note-seq`") + + +INPUT_FEATURE_LENGTH = 2048 + +SAMPLE_RATE = 16000 +HOP_SIZE = 320 +FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE) + +DEFAULT_STEPS_PER_SECOND = 100 +DEFAULT_MAX_SHIFT_SECONDS = 10 +DEFAULT_NUM_VELOCITY_BINS = 1 + +SLAKH_CLASS_PROGRAMS = { + "Acoustic Piano": 0, + "Electric Piano": 4, + "Chromatic Percussion": 8, + "Organ": 16, + "Acoustic Guitar": 24, + "Clean Electric Guitar": 26, + "Distorted Electric Guitar": 29, + "Acoustic Bass": 32, + "Electric Bass": 33, + "Violin": 40, + "Viola": 41, + "Cello": 42, + "Contrabass": 43, + "Orchestral Harp": 46, + "Timpani": 47, + "String Ensemble": 48, + "Synth Strings": 50, + "Choir and Voice": 52, + "Orchestral Hit": 55, + "Trumpet": 56, + "Trombone": 57, + "Tuba": 58, + "French Horn": 60, + "Brass Section": 61, + "Soprano/Alto Sax": 64, + "Tenor Sax": 66, + "Baritone Sax": 67, + "Oboe": 68, + "English Horn": 69, + "Bassoon": 70, + "Clarinet": 71, + "Pipe": 73, + "Synth Lead": 80, + "Synth Pad": 88, +} + + +@dataclasses.dataclass +class NoteRepresentationConfig: + """Configuration note representations.""" + + onsets_only: bool + include_ties: bool + + +@dataclasses.dataclass +class NoteEventData: + pitch: int + velocity: Optional[int] = None + program: Optional[int] = None + is_drum: Optional[bool] = None + instrument: Optional[int] = None + + +@dataclasses.dataclass +class NoteEncodingState: + """Encoding state for note transcription, keeping track of active pitches.""" + + # velocity bin for active pitches and programs + active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class EventRange: + type: str + min_value: int + max_value: int + + +@dataclasses.dataclass +class Event: + type: str + value: int + + +class Tokenizer: + def __init__(self, regular_ids: int): + # The special tokens: 0=PAD, 1=EOS, and 2=UNK + self._num_special_tokens = 3 + self._num_regular_tokens = regular_ids + + def encode(self, token_ids): + encoded = [] + for token_id in token_ids: + if not 0 <= token_id < self._num_regular_tokens: + raise ValueError( + f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})" + ) + encoded.append(token_id + self._num_special_tokens) + + # Add EOS token + encoded.append(1) + + # Pad to till INPUT_FEATURE_LENGTH + encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded)) + + return encoded + + +class Codec: + """Encode and decode events. + + Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from + Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not + include things like EOS or UNK token handling. + + To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required + and specified separately. + """ + + def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]): + """Define Codec. + + Args: + max_shift_steps: Maximum number of shift steps that can be encoded. + steps_per_second: Shift steps will be interpreted as having a duration of + 1 / steps_per_second. + event_ranges: Other supported event types and their ranges. + """ + self.steps_per_second = steps_per_second + self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps) + self._event_ranges = [self._shift_range] + event_ranges + # Ensure all event types have unique names. + assert len(self._event_ranges) == len({er.type for er in self._event_ranges}) + + @property + def num_classes(self) -> int: + return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) + + # The next couple methods are simplified special case methods just for shift + # events that are intended to be used from within autograph functions. + + def is_shift_event_index(self, index: int) -> bool: + return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value) + + @property + def max_shift_steps(self) -> int: + return self._shift_range.max_value + + def encode_event(self, event: Event) -> int: + """Encode an event to an index.""" + offset = 0 + for er in self._event_ranges: + if event.type == er.type: + if not er.min_value <= event.value <= er.max_value: + raise ValueError( + f"Event value {event.value} is not within valid range " + f"[{er.min_value}, {er.max_value}] for type {event.type}" + ) + return offset + event.value - er.min_value + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event.type}") + + def event_type_range(self, event_type: str) -> Tuple[int, int]: + """Return [min_id, max_id] for an event type.""" + offset = 0 + for er in self._event_ranges: + if event_type == er.type: + return offset, offset + (er.max_value - er.min_value) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event_type}") + + def decode_event_index(self, index: int) -> Event: + """Decode an event index to an Event.""" + offset = 0 + for er in self._event_ranges: + if offset <= index <= offset + er.max_value - er.min_value: + return Event(type=er.type, value=er.min_value + index - offset) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event index: {index}") + + +@dataclasses.dataclass +class ProgramGranularity: + # both tokens_map_fn and program_map_fn should be idempotent + tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]] + program_map_fn: Callable[[int], int] + + +def drop_programs(tokens, codec: Codec): + """Drops program change events from a token sequence.""" + min_program_id, max_program_id = codec.event_type_range("program") + return tokens[(tokens < min_program_id) | (tokens > max_program_id)] + + +def programs_to_midi_classes(tokens, codec): + """Modifies program events to be the first program in the MIDI class.""" + min_program_id, max_program_id = codec.event_type_range("program") + is_program = (tokens >= min_program_id) & (tokens <= max_program_id) + return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens) + + +PROGRAM_GRANULARITIES = { + # "flat" granularity; drop program change tokens and set NoteSequence + # programs to zero + "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0), + # map each program to the first program in its MIDI class + "midi_class": ProgramGranularity( + tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8) + ), + # leave programs as is + "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program), +} + + +def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1): + """ + equivalent of tf.signal.frame + """ + signal_length = signal.shape[axis] + if pad_end: + frames_overlap = frame_length - frame_step + rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap) + pad_size = int(frame_length - rest_samples) + + if pad_size != 0: + pad_axis = [0] * signal.ndim + pad_axis[axis] = pad_size + signal = F.pad(signal, pad_axis, "constant", pad_value) + frames = signal.unfold(axis, frame_length, frame_step) + return frames + + +def program_to_slakh_program(program): + # this is done very hackily, probably should use a custom mapping + for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True): + if program >= slakh_program: + return slakh_program + + +def audio_to_frames( + samples, + hop_size: int, + frame_rate: int, +) -> Tuple[Sequence[Sequence[int]], torch.Tensor]: + """Convert audio samples to non-overlapping frames and frame times.""" + frame_size = hop_size + samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant") + + # Split audio into frames. + frames = frame( + torch.Tensor(samples).unsqueeze(0), + frame_length=frame_size, + frame_step=frame_size, + pad_end=False, # TODO check why its off by 1 here when True + ) + + num_frames = len(samples) // frame_size + + times = np.arange(num_frames) / frame_rate + return frames, times + + +def note_sequence_to_onsets_and_offsets_and_programs( + ns: note_seq.NoteSequence, +) -> Tuple[Sequence[float], Sequence[NoteEventData]]: + """Extract onset & offset times and pitches & programs from a NoteSequence. + + The onset & offset times will not necessarily be in sorted order. + + Args: + ns: NoteSequence from which to extract onsets and offsets. + + Returns: + times: A list of note onset and offset times. values: A list of NoteEventData objects where velocity is zero for + note + offsets. + """ + # Sort by program and pitch and put offsets before onsets as a tiebreaker for + # subsequent stable sort. + notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch)) + times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes] + values = [ + NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False) + for note in notes + if not note.is_drum + ] + [ + NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum) + for note in notes + ] + return times, values + + +def num_velocity_bins_from_codec(codec: Codec): + """Get number of velocity bins from event codec.""" + lo, hi = codec.event_type_range("velocity") + return hi - lo + + +# segment an array into segments of length n +def segment(a, n): + return [a[i : i + n] for i in range(0, len(a), n)] + + +def velocity_to_bin(velocity, num_velocity_bins): + if velocity == 0: + return 0 + else: + return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY) + + +def note_event_data_to_events( + state: Optional[NoteEncodingState], + value: NoteEventData, + codec: Codec, +) -> Sequence[Event]: + """Convert note event data to a sequence of events.""" + if value.velocity is None: + # onsets only, no program or velocity + return [Event("pitch", value.pitch)] + else: + num_velocity_bins = num_velocity_bins_from_codec(codec) + velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins) + if value.program is None: + # onsets + offsets + velocities only, no programs + if state is not None: + state.active_pitches[(value.pitch, 0)] = velocity_bin + return [Event("velocity", velocity_bin), Event("pitch", value.pitch)] + else: + if value.is_drum: + # drum events use a separate vocabulary + return [Event("velocity", velocity_bin), Event("drum", value.pitch)] + else: + # program + velocity + pitch + if state is not None: + state.active_pitches[(value.pitch, value.program)] = velocity_bin + return [ + Event("program", value.program), + Event("velocity", velocity_bin), + Event("pitch", value.pitch), + ] + + +def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]: + """Output program and pitch events for active notes plus a final tie event.""" + events = [] + for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]): + if state.active_pitches[(pitch, program)]: + events += [Event("program", program), Event("pitch", pitch)] + events.append(Event("tie", 0)) + return events + + +def encode_and_index_events( + state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None +): + """Encode a sequence of timed events and index to audio frame times. + + Encodes time shifts as repeated single step shifts for later run length encoding. + + Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio + frame. This can be used e.g. to prepend events representing the current state to a targets segment. + + Args: + state: Initial event encoding state. + event_times: Sequence of event times. + event_values: Sequence of event values. + encode_event_fn: Function that transforms event value into a sequence of one + or more Event objects. + codec: An Codec object that maps Event objects to indices. + frame_times: Time for every audio frame. + encoding_state_to_events_fn: Function that transforms encoding state into a + sequence of one or more Event objects. + + Returns: + events: Encoded events and shifts. event_start_indices: Corresponding start event index for every audio frame. + Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes + splitting sequences tricky because the same event can appear at the end of one sequence and the beginning of + another. + event_end_indices: Corresponding end event index for every audio frame. Used + to ensure when slicing that one chunk ends where the next begins. Should always be true that + event_end_indices[i] = event_start_indices[i + 1]. + state_events: Encoded "state" events representing the encoding state before + each event. + state_event_indices: Corresponding state event index for every audio frame. + """ + indices = np.argsort(event_times, kind="stable") + event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices] + event_values = [event_values[i] for i in indices] + + events = [] + state_events = [] + event_start_indices = [] + state_event_indices = [] + + cur_step = 0 + cur_event_idx = 0 + cur_state_event_idx = 0 + + def fill_event_start_indices_to_cur_step(): + while ( + len(event_start_indices) < len(frame_times) + and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second + ): + event_start_indices.append(cur_event_idx) + state_event_indices.append(cur_state_event_idx) + + for event_step, event_value in zip(event_steps, event_values): + while event_step > cur_step: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + cur_state_event_idx = len(state_events) + if encoding_state_to_events_fn: + # Dump state to state events *before* processing the next event, because + # we want to capture the state prior to the occurrence of the event. + for e in encoding_state_to_events_fn(state): + state_events.append(codec.encode_event(e)) + + for e in encode_event_fn(state, event_value, codec): + events.append(codec.encode_event(e)) + + # After the last event, continue filling out the event_start_indices array. + # The inequality is not strict because if our current step lines up exactly + # with (the start of) an audio frame, we need to add an additional shift event + # to "cover" that frame. + while cur_step / codec.steps_per_second <= frame_times[-1]: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + + # Now fill in event_end_indices. We need this extra array to make sure that + # when we slice events, each slice ends exactly where the subsequent slice + # begins. + event_end_indices = event_start_indices[1:] + [len(events)] + + events = np.array(events).astype(np.int32) + state_events = np.array(state_events).astype(np.int32) + event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + + outputs = [] + for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices): + outputs.append( + { + "inputs": events, + "event_start_indices": start_indices, + "event_end_indices": end_indices, + "state_events": state_events, + "state_event_indices": event_indices, + } + ) + + return outputs + + +def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"): + """Extract target sequence corresponding to audio token segment.""" + features = features.copy() + start_idx = features["event_start_indices"][0] + end_idx = features["event_end_indices"][-1] + + features[feature_key] = features[feature_key][start_idx:end_idx] + + if state_events_end_token is not None: + # Extract the state events corresponding to the audio start token, and + # prepend them to the targets array. + state_event_start_idx = features["state_event_indices"][0] + state_event_end_idx = state_event_start_idx + 1 + while features["state_events"][state_event_end_idx - 1] != state_events_end_token: + state_event_end_idx += 1 + features[feature_key] = np.concatenate( + [ + features["state_events"][state_event_start_idx:state_event_end_idx], + features[feature_key], + ], + axis=0, + ) + + return features + + +def map_midi_programs( + feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs" +) -> Mapping[str, Any]: + """Apply MIDI program map to token sequences.""" + granularity = PROGRAM_GRANULARITIES[granularity_type] + + feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec) + return feature + + +def run_length_encode_shifts_fn( + features, + codec: Codec, + feature_key: str = "inputs", + state_change_event_types: Sequence[str] = (), +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + """Return a function that run-length encodes shifts for a given codec. + + Args: + codec: The Codec to use for shift events. + feature_key: The feature key for which to run-length encode shifts. + state_change_event_types: A list of event types that represent state + changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones + will be removed. + + Returns: + A preprocessing function that run-length encodes single-step shifts. + """ + state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types] + + def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]: + """Combine leading/interior shifts, trim trailing shifts. + + Args: + features: Dict of features to process. + + Returns: + A dict of features. + """ + events = features[feature_key] + + shift_steps = 0 + total_shift_steps = 0 + output = np.array([], dtype=np.int32) + + current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32) + + for event in events: + if codec.is_shift_event_index(event): + shift_steps += 1 + total_shift_steps += 1 + + else: + # If this event is a state change and has the same value as the current + # state, we can skip it entirely. + is_redundant = False + for i, (min_index, max_index) in enumerate(state_change_event_ranges): + if (min_index <= event) and (event <= max_index): + if current_state[i] == event: + is_redundant = True + current_state[i] = event + if is_redundant: + continue + + # Once we've reached a non-shift event, RLE all previous shift events + # before outputting the non-shift event. + if shift_steps > 0: + shift_steps = total_shift_steps + while shift_steps > 0: + output_steps = np.minimum(codec.max_shift_steps, shift_steps) + output = np.concatenate([output, [output_steps]], axis=0) + shift_steps -= output_steps + output = np.concatenate([output, [event]], axis=0) + + features[feature_key] = output + return features + + return run_length_encode_shifts(features) + + +def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig): + tie_token = codec.encode_event(Event("tie", 0)) + state_events_end_token = tie_token if note_representation_config.include_ties else None + + features = extract_sequence_with_indices( + features, state_events_end_token=state_events_end_token, feature_key="inputs" + ) + + features = map_midi_programs(features, codec) + + features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"]) + + return features + + +class MidiProcessor: + def __init__(self): + self.codec = Codec( + max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND, + steps_per_second=DEFAULT_STEPS_PER_SECOND, + event_ranges=[ + EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS), + EventRange("tie", 0, 0), + EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM), + EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + ], + ) + self.tokenizer = Tokenizer(self.codec.num_classes) + self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True) + + def __call__(self, midi: Union[bytes, os.PathLike, str]): + if not isinstance(midi, bytes): + with open(midi, "rb") as f: + midi = f.read() + + ns = note_seq.midi_to_note_sequence(midi) + ns_sus = note_seq.apply_sustain_control_changes(ns) + + for note in ns_sus.notes: + if not note.is_drum: + note.program = program_to_slakh_program(note.program) + + samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE)) + + _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE) + times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus) + + events = encode_and_index_events( + state=NoteEncodingState(), + event_times=times, + event_values=values, + frame_times=frame_times, + codec=self.codec, + encode_event_fn=note_event_data_to_events, + encoding_state_to_events_fn=note_encoding_state_to_events, + ) + + events = [ + note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events + ] + input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events] + + return input_tokens diff --git a/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py b/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..94eaa176f3e5a15f4065e78b4b7714fa8c51ca83 --- /dev/null +++ b/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py @@ -0,0 +1,86 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + max_length: int, + vocab_size: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.token_embedder = nn.Embedding(vocab_size, d_model) + + self.position_encoding = nn.Embedding(max_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + vocab_size=vocab_size, + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + dropout_rate=dropout_rate, + feed_forward_proj=feed_forward_proj, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_input_tokens, encoder_inputs_mask): + x = self.token_embedder(encoder_input_tokens) + + seq_length = encoder_input_tokens.shape[1] + inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) + x += self.position_encoding(inputs_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_input_tokens.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..66155ebf7f35cbe224bf21fd54c47f3b5ee32a37 --- /dev/null +++ b/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py @@ -0,0 +1,210 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...models import T5FilmDecoder +from ...schedulers import DDPMScheduler +from ...utils import is_onnx_available, logging, randn_tensor + + +if is_onnx_available(): + from ..onnx_utils import OnnxRuntimeModel + +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .continous_encoder import SpectrogramContEncoder +from .notes_encoder import SpectrogramNotesEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +TARGET_FEATURE_LENGTH = 256 + + +class SpectrogramDiffusionPipeline(DiffusionPipeline): + _optional_components = ["melgan"] + + def __init__( + self, + notes_encoder: SpectrogramNotesEncoder, + continuous_encoder: SpectrogramContEncoder, + decoder: T5FilmDecoder, + scheduler: DDPMScheduler, + melgan: OnnxRuntimeModel if is_onnx_available() else Any, + ) -> None: + super().__init__() + + # From MELGAN + self.min_value = math.log(1e-5) # Matches MelGAN training. + self.max_value = 4.0 # Largest value for most examples + self.n_dims = 128 + + self.register_modules( + notes_encoder=notes_encoder, + continuous_encoder=continuous_encoder, + decoder=decoder, + scheduler=scheduler, + melgan=melgan, + ) + + def scale_features(self, features, output_range=(-1.0, 1.0), clip=False): + """Linearly scale features to network outputs range.""" + min_out, max_out = output_range + if clip: + features = torch.clip(features, self.min_value, self.max_value) + # Scale to [0, 1]. + zero_one = (features - self.min_value) / (self.max_value - self.min_value) + # Scale to [min_out, max_out]. + return zero_one * (max_out - min_out) + min_out + + def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False): + """Invert by linearly scaling network outputs to features range.""" + min_out, max_out = input_range + outputs = torch.clip(outputs, min_out, max_out) if clip else outputs + # Scale to [0, 1]. + zero_one = (outputs - min_out) / (max_out - min_out) + # Scale to [self.min_value, self.max_value]. + return zero_one * (self.max_value - self.min_value) + self.min_value + + def encode(self, input_tokens, continuous_inputs, continuous_mask): + tokens_mask = input_tokens > 0 + tokens_encoded, tokens_mask = self.notes_encoder( + encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask + ) + + continuous_encoded, continuous_mask = self.continuous_encoder( + encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask + ) + + return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)] + + def decode(self, encodings_and_masks, input_tokens, noise_time): + timesteps = noise_time + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(input_tokens.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + logits = self.decoder( + encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps + ) + return logits + + @torch.no_grad() + def __call__( + self, + input_tokens: List[List[int]], + generator: Optional[torch.Generator] = None, + num_inference_steps: int = 100, + return_dict: bool = True, + output_type: str = "numpy", + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ) -> Union[AudioPipelineOutput, Tuple]: + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32) + full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32) + ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + + for i, encoder_input_tokens in enumerate(input_tokens): + if i == 0: + encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to( + device=self.device, dtype=self.decoder.dtype + ) + # The first chunk has no previous context. + encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + else: + # The full song pipeline does not feed in a context feature, so the mask + # will be all 0s after the feature converter. Because we know we're + # feeding in a full context chunk from the previous prediction, set it + # to all 1s. + encoder_continuous_mask = ones + + encoder_continuous_inputs = self.scale_features( + encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True + ) + + encodings_and_masks = self.encode( + input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device), + continuous_inputs=encoder_continuous_inputs, + continuous_mask=encoder_continuous_mask, + ) + + # Sample encoder_continuous_inputs shaped gaussian noise to begin loop + x = randn_tensor( + shape=encoder_continuous_inputs.shape, + generator=generator, + device=self.device, + dtype=self.decoder.dtype, + ) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + # Denoising diffusion loop + for j, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + output = self.decode( + encodings_and_masks=encodings_and_masks, + input_tokens=x, + noise_time=t / self.scheduler.config.num_train_timesteps, # rescale to [0, 1) + ) + + # Compute previous output: x_t -> x_t-1 + x = self.scheduler.step(output, t, x, generator=generator).prev_sample + + mel = self.scale_to_features(x, input_range=[-1.0, 1.0]) + encoder_continuous_inputs = mel[:1] + pred_mel = mel.cpu().float().numpy() + + full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, full_pred_mel) + + logger.info("Generated segment", i) + + if output_type == "numpy" and not is_onnx_available(): + raise ValueError( + "Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'." + ) + elif output_type == "numpy" and self.melgan is None: + raise ValueError( + "Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'." + ) + + if output_type == "numpy": + output = self.melgan(input_features=full_pred_mel.astype(np.float32)) + else: + output = full_pred_mel + + if not return_dict: + return (output,) + + return AudioPipelineOutput(audios=output) diff --git a/diffusers/pipelines/stable_diffusion/README.md b/diffusers/pipelines/stable_diffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..66df9a811afbf70a5e943ed1a1e3e7c6955e6c25 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/README.md @@ -0,0 +1,176 @@ +# Stable Diffusion + +## Overview + +Stable Diffusion was proposed in [Stable Diffusion Announcement](https://stability.ai/blog/stable-diffusion-announcement) by Patrick Esser and Robin Rombach and the Stability AI team. + +The summary of the model is the following: + +*Stable Diffusion is a text-to-image model that will empower billions of people to create stunning art within seconds. It is a breakthrough in speed and quality meaning that it can run on consumer GPUs. You can see some of the amazing output that has been created by this model without pre or post-processing on this page. The model itself builds upon the work of the team at CompVis and Runway in their widely used latent diffusion model combined with insights from the conditional diffusion models by our lead generative AI developer Katherine Crowson, Dall-E 2 by Open AI, Imagen by Google Brain and many others. We are delighted that AI media generation is a cooperative field and hope it can continue this way to bring the gift of creativity to all.* + +## Tips: + +- Stable Diffusion has the same architecture as [Latent Diffusion](https://arxiv.org/abs/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model. +- An in-detail explanation of the Stable Diffusion model can be found under [Stable Diffusion with 🧨 Diffusers](https://huggingface.co/blog/stable_diffusion). +- If you don't want to rely on the Hugging Face Hub and having to pass a authentication token, you can +download the weights with `git lfs install; git clone https://huggingface.co/runwayml/stable-diffusion-v1-5` and instead pass the local path to the cloned folder to `from_pretrained` as shown below. +- Stable Diffusion can work with a variety of different samplers as is shown below. + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) + +## Examples: + +### Using Stable Diffusion without being logged into the Hub. + +If you want to download the model weights using a single Python line, you need to be logged in via `huggingface-cli login`. + +```python +from diffusers import DiffusionPipeline + +pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +``` + +This however can make it difficult to build applications on top of `diffusers` as you will always have to pass the token around. A potential way to solve this issue is by downloading the weights to a local path `"./stable-diffusion-v1-5"`: + +``` +git lfs install +git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 +``` + +and simply passing the local path to `from_pretrained`: + +```python +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") +``` + +### Text-to-Image with default PLMS scheduler + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] + +image.save("astronaut_rides_horse.png") +``` + +### Text-to-Image with DDIM scheduler + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline, DDIMScheduler + +scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + scheduler=scheduler, +).to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] + +image.save("astronaut_rides_horse.png") +``` + +### Text-to-Image with K-LMS scheduler + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler + +lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + scheduler=lms, +).to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] + +image.save("astronaut_rides_horse.png") +``` + +### CycleDiffusion using Stable Diffusion and DDIM scheduler + +```python +import requests +import torch +from PIL import Image +from io import BytesIO + +from diffusers import CycleDiffusionPipeline, DDIMScheduler + + +# load the scheduler. CycleDiffusion only supports stochastic schedulers. + +# load the pipeline +# make sure you're logged in with `huggingface-cli login` +model_id_or_path = "CompVis/stable-diffusion-v1-4" +scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") +pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") + +# let's download an initial image +url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +init_image.save("horse.png") + +# let's specify a prompt +source_prompt = "An astronaut riding a horse" +prompt = "An astronaut riding an elephant" + +# call the pipeline +image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.8, + guidance_scale=2, + source_guidance_scale=1, +).images[0] + +image.save("horse_to_elephant.png") + +# let's try another example +# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion +url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) +init_image.save("black.png") + +source_prompt = "A black colored car" +prompt = "A blue colored car" + +# call the pipeline +torch.manual_seed(0) +image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, +).images[0] + +image.save("black_to_blue.png") +``` diff --git a/diffusers/pipelines/stable_diffusion/__init__.py b/diffusers/pipelines/stable_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33ab05a1dacbdfdfc02966675de4c30cb1069a10 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/__init__.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import ( + BaseOutput, + OptionalDependencyNotAvailable, + is_flax_available, + is_k_diffusion_available, + is_k_diffusion_version, + is_onnx_available, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_cycle_diffusion import CycleDiffusionPipeline + from .pipeline_stable_diffusion import StableDiffusionPipeline + from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline + from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline + from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy + from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline + from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline + from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline + from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline + from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline + from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline + from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline + from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline + from .pipeline_stable_unclip import StableUnCLIPPipeline + from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline + from .safety_checker import StableDiffusionSafetyChecker + from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline +else: + from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline + + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, + StableDiffusionPix2PixZeroPipeline, + ) +else: + from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline + + +try: + if not ( + is_torch_available() + and is_transformers_available() + and is_k_diffusion_available() + and is_k_diffusion_version(">=", "0.0.12") + ): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 +else: + from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline + +try: + if not (is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_onnx_objects import * # noqa F403 +else: + from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline + from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline + from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline + from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy + from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline + +if is_transformers_available() and is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`np.ndarray`) + Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. + """ + + images: np.ndarray + nsfw_content_detected: List[bool] + + from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState + from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline + from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fe8a5e86fcb9abc1e5bd24979a297e5f4a61fa0 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be6e5dbc542980a2a9c181d51c0d121b8c5daaad Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/convert_from_ckpt.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/convert_from_ckpt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1efdce6f5f17dd9e96e237838121bcd20901ea2 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/convert_from_ckpt.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..310f483c52a4c2b5764a000bd38fc64159054a9e Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23cdaf1ea18d26d34e21833b6a4613ab3da8d560 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e954a2539683f45e4550487277122d0c8d1016d6 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..740e970b3f04f06fa677d743bdab640888b6a15a Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..265fccce7fc35213d4534435fe2970ae7ebbcf80 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b42d5b10821b289f8d5c5bbf62ec1a2d63f65bc9 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a56995c159ce88c7b6dabd94a89ecb2f0e464b9 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f7c973124b9ebbf8caaf31dd0286aacd58a984 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_diffedit.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_diffedit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e957f1acc09ad47c59f56f41373a75ac2552cc1 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_diffedit.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_diffedit.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_diffedit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..469564dbfc6d3d22c9335406beefafae06b100ca Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_diffedit.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..174e4b0511c6a12cdeaab0cc3c662f2a88880958 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd2779536ca92718804bbaee0682a90e59e78aa9 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..059cb08407edd44cc9e259a560ae464a3a718d84 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a3598f0188cd4a18e09a66212ca3ec4e78d5e42 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56cc8a71f4ba495052576d8c799e287cee41284e Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67f8a7a60eb8165ba4c449fbaa9375a865145d15 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a87881071f44f32b23e4b83469b180a190ad8a Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b79e5fe13c882f24ab56d5d9a3e7bb0956e427e Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f902de3b0e0f270d96292fd0fd2ed9c178cef09 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b817d0eac72a93b73a2a3428c998c55f9942c97d Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92e8e5a1a31f0b84ac17f2f1d01414c732cd490d Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..054340af5112103fcc31a446100f38def1963177 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41744f8c3f44e67fb16d9f66e8e2b1e6638f128f Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40b0af5da41ff585583cae35582084ce525ff9ac Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d6e5409c2d52c75d344b7fb9e650a0e8688e751 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b7cc10ebbdcfe5dec86e1e015d6e5a9a6ae857 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_model_editing.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dee35d3c90bf3c2122fa9a10ecc0429453f1bd5f Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86a259b57ad5f316a50655c02f5345fe3a988736 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_paradigms.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_paradigms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac152a1991c3a55335aef4ee3332affc76d0a87f Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_paradigms.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_paradigms.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_paradigms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9d83f26ffeb9e48f89b8e3a8560bf49cbdcade4 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_paradigms.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78aa983eef72b178bc75eb7748e8ff2505e04cdc Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82865fe2942026cdd3e092f2d423318bc8c533ed Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..062381426a02ffc57f12661de692aeed3ced1379 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffee4e5c133f3bbb624db8ee208858ea07c20bd3 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f63e62ca11890edd63dd1427644a6164a2119b55 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6e99169784cd2d8c4c1ce7e249b02ef72418612 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16b4297d4174ad8a978699b8413e6c2868b73dd3 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efefb76477b66d6756abb0ee9c2401ae5d21ef29 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac8e3d20e59e4ad36c04dd8efb8117689263dd3d Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d538ed027c556579fb4e02653da477bfe2c98ecc Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4790747846014eb8c3c861c015f91b8644ea3de3 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccb27ea686000dc76cdae7d11216b032ac24193e Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-310.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80ed3f0730f15935eefe5178d612d7b5f32016f8 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-38.pyc b/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb913e81c25ea5814215460d4fd8923636f7767e Binary files /dev/null and b/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..0eeb80f12dfc9a10f0479ce711ee85869d839dd5 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -0,0 +1,1636 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the Stable Diffusion checkpoints.""" + +import re +from contextlib import nullcontext +from io import BytesIO +from typing import Optional + +import requests +import torch +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from ...models import ( + AutoencoderKL, + ControlNetModel, + PriorTransformer, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging +from ...utils.import_utils import BACKENDS_MAPPING +from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from ..paint_by_example import PaintByExampleImageEncoder +from ..pipeline_utils import DiffusionPipeline +from .safety_checker import StableDiffusionSafetyChecker +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config.model.parms.cond_stage_config.params + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config_name = "openai/clip-vit-large-patch14" + config = CLIPTextConfig.from_pretrained(config_name) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path: str, + original_config_file: str = None, + image_size: Optional[int] = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: Optional[str] = None, + clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, + load_safety_checker: bool = True, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + text_encoder=None, + tokenizer=None, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. This is necessary when running stable + diffusion 2.1. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. + pipeline_class (`str`, *optional*, defaults to `None`): + The pipeline class to use. Pass `None` to determine automatically. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): + An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) + to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) + variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): + An instance of + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if + needed. + return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + # import pipelines here to avoid circular import error when using from_single_file method + from diffusers import ( + LDMTextToImagePipeline, + PaintByExamplePipeline, + StableDiffusionControlNetPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionXLImg2ImgPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + + if pipeline_class is None: + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline + + if prediction_type == "v-prediction": + prediction_type = "v_prediction" + + if not is_omegaconf_available(): + raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) + + from omegaconf import OmegaConf + + if from_safetensors: + if not is_safetensors_available(): + raise ValueError(BACKENDS_MAPPING["safetensors"][1]) + + from safetensors.torch import load_file as safe_load + + checkpoint = safe_load(checkpoint_path, device="cpu") + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + + # model_type = "v1" + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: + # model_type = "v2" + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + original_config_file = BytesIO(requests.get(config_url).content) + + original_config = OmegaConf.load(original_config_file) + + # Convert the text model. + if ( + model_type is None + and "cond_stage_config" in original_config.model.params + and original_config.model.params.cond_stage_config is not None + ): + model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config.model.params.network_config is not None: + if original_config.model.params.network_config.params.context_dim == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: + num_in_channels = 9 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config.model.params: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None: + controlnet = "control_stage_config" in original_config.model.params + + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + ) + + num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 + beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + else: + vae = AutoencoderKL.from_pretrained(vae_path) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) + tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == "karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") + + prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") + feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + else: + safety_checker = None + feature_extractor = None + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + if model_type == "SDXL": + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") + + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs + ) + + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + else: + tokenizer = None + text_encoder = None + tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") + + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs + ) + + pipe = StableDiffusionXLImg2ImgPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + requires_aesthetics_score=True, + force_zeros_for_empty_prompt=False, + ) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + return pipe + + +def download_controlnet_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, +) -> DiffusionPipeline: + if not is_omegaconf_available(): + raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) + + from omegaconf import OmegaConf + + if from_safetensors: + if not is_safetensors_available(): + raise ValueError(BACKENDS_MAPPING["safetensors"][1]) + + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + original_config = OmegaConf.load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config.model.params: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet diff --git a/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9a68c4d059c6ac4180cf5d0556b4a9d380497213 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -0,0 +1,796 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.utils import is_accelerate_available, is_accelerate_version + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + if prev_timestep <= 0: + return clean_latents + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # direction pointing to x_t + e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5) + dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t + noise = std_dev_t * randn_tensor( + clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator + ) + prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise + + return prev_latents + + +def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / ( + variance ** (0.5) * eta + ) + return noise + + +class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + image = image.to(device=device, dtype=dtype) + + batch_size = image.shape[0] + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timestep + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + clean_latents = init_latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents, clean_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + source_prompt: Union[str, List[str]], + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + source_guidance_scale: Optional[float] = 1, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + source_guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale for the source prompt. This is useful to control the amount of influence the source + prompt for encoding. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.1): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + source_prompt_embeds = self._encode_prompt( + source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, clean_latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + source_latents = latents + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + generator = extra_step_kwargs.pop("generator", None) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + source_latent_model_input = torch.cat([source_latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) + + # predict the noise residual + concat_latent_model_input = torch.stack( + [ + source_latent_model_input[0], + latent_model_input[0], + source_latent_model_input[1], + latent_model_input[1], + ], + dim=0, + ) + concat_prompt_embeds = torch.stack( + [ + source_prompt_embeds[0], + prompt_embeds[0], + source_prompt_embeds[1], + prompt_embeds[1], + ], + dim=0, + ) + concat_noise_pred = self.unet( + concat_latent_model_input, + t, + cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states=concat_prompt_embeds, + ).sample + + # perform guidance + ( + source_noise_pred_uncond, + noise_pred_uncond, + source_noise_pred_text, + noise_pred_text, + ) = concat_noise_pred.chunk(4, dim=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( + source_noise_pred_text - source_noise_pred_uncond + ) + + # Sample source_latents from the posterior distribution. + prev_source_latents = posterior_sample( + self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs + ) + # Compute noise. + noise = compute_noise( + self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs + ) + source_latents = prev_source_latents + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4f77029ce45497abea4807e97dc8656aaa6a99 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -0,0 +1,470 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from packaging import version +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import deprecate, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from . import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + + >>> from diffusers import FlaxStableDiffusionPipeline + + >>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16 + ... ) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + + >>> prng_seed = jax.random.PRNGKey(0) + >>> num_inference_steps = 50 + + >>> num_samples = jax.device_count() + >>> prompt = num_samples * [prompt] + >>> prompt_ids = pipeline.prepare_inputs(prompt) + # shard inputs and rng + + >>> params = replicate(params) + >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) + >>> prompt_ids = shard(prompt_ids) + + >>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images + >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: Union[float, jnp.array] = 7.5, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. tensor will ge generated + by sampling using the supplied random `generator`. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 4, 5, 6), +) +def _p_generate( + pipe, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..bec2424ece4dc91fbafd530d525e36d1fb84c4ff --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -0,0 +1,28 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: This file is deprecated and will be removed in a future version. +# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works + +from ...utils import deprecate +from ..controlnet.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline # noqa: F401 + + +deprecate( + "stable diffusion controlnet", + "0.22.0", + "Importing `FlaxStableDiffusionControlNetPipeline` from diffusers.pipelines.stable_diffusion.flax_pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import FlaxStableDiffusionControlNetPipeline` instead.", + standard_warn=False, + stacklevel=3, +) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..6a387af364b7467a9f88d537071a48e001f99b69 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -0,0 +1,527 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from . import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> init_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_img = init_img.resize((768, 512)) + + >>> prompts = "A fantasy landscape, trending on artstation" + + >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... revision="flax", + ... dtype=jnp.bfloat16, + ... ) + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + >>> prompt_ids, processed_image = pipeline.prepare_inputs( + ... prompt=[prompts] * num_samples, image=[init_img] * num_samples + ... ) + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipeline( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... strength=0.75, + ... num_inference_steps=50, + ... jit=True, + ... height=512, + ... width=768, + ... ).images + + >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for image-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def get_timestep_start(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + return t_start + + def _generate( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + start_timestep: int, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + noise: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if noise is None: + noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if noise.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}") + + # Create init_latents + init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist + init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) + init_latents = self.vae.config.scaling_factor * init_latents + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size) + + latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(start_timestep, num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + strength: float = 0.8, + num_inference_steps: int = 50, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: Union[float, jnp.array] = 7.5, + noise: jnp.array = None, + neg_prompt_ids: jnp.array = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt_ids (`jnp.array`): + The prompt or prompts to guide the image generation. + image (`jnp.array`): + Array representing an image batch, that will be used as the starting point for the process. + params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights + prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + noise (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. tensor will ge generated + by sampling using the supplied random `generator`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + start_timestep = self.get_timestep_start(num_inference_steps, strength) + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 5, 6, 7, 8), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..abb57f8b62e9aab62b7dc83329ab2a3c1f623532 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -0,0 +1,580 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from packaging import version +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from . import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import PIL + >>> import requests + >>> from io import BytesIO + >>> from diffusers import FlaxStableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained( + ... "xvjiarui/stable-diffusion-2-inpainting" + ... ) + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> prng_seed = jax.random.PRNGKey(0) + >>> num_inference_steps = 50 + + >>> num_samples = jax.device_count() + >>> prompt = num_samples * [prompt] + >>> init_image = num_samples * [init_image] + >>> mask_image = num_samples * [mask_image] + >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs( + ... prompt, init_image, mask_image + ... ) + # shard inputs and rng + + >>> params = replicate(params) + >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) + >>> prompt_ids = shard(prompt_ids) + >>> processed_masked_images = shard(processed_masked_images) + >>> processed_masks = shard(processed_masks) + + >>> images = pipeline( + ... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True + ... ).images + >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs( + self, + prompt: Union[str, List[str]], + image: Union[Image.Image, List[Image.Image]], + mask: Union[Image.Image, List[Image.Image]], + ): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + if not isinstance(mask, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(mask, Image.Image): + mask = [mask] + + processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image]) + processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask]) + # processed_masks[processed_masks < 0.5] = 0 + processed_masks = processed_masks.at[processed_masks < 0.5].set(0) + # processed_masks[processed_masks >= 0.5] = 1 + processed_masks = processed_masks.at[processed_masks >= 0.5].set(1) + + processed_masked_images = processed_images * (processed_masks < 0.5) + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_masked_images, processed_masks + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + mask: jnp.array, + masked_image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + latents_shape = ( + batch_size, + self.vae.config.latent_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + prng_seed, mask_prng_seed = jax.random.split(prng_seed) + + masked_image_latent_dist = self.vae.apply( + {"params": params["vae"]}, masked_image, method=self.vae.encode + ).latent_dist + masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2)) + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + del mask_prng_seed + + mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest") + + # 8. Check that sizes of mask, masked image and latents match + num_channels_latents = self.vae.config.latent_channels + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + def loop_body(step, args): + latents, mask, masked_image_latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + mask_input = jnp.concatenate([mask] * 2) + masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + # concat latents, mask, masked_image_latents in the channel dimension + latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, mask, masked_image_latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, mask, masked_image_latents, scheduler_state = loop_body( + i, (latents, mask, masked_image_latents, scheduler_state) + ) + else: + latents, _, _, _ = jax.lax.fori_loop( + 0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state) + ) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + mask: jnp.array, + masked_image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: Union[float, jnp.array] = 7.5, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. tensor will ge generated + by sampling using the supplied random `generator`. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic") + mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest") + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 6, 7, 8), +) +def _p_generate( + pipe, + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess_image(image, dtype): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, dtype): + w, h = mask.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w, h)) + mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 + mask = jnp.expand_dims(mask, axis=(0, 1)) + + return mask diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..eb02f6cb321cb02ec5bd7badc0f6c73f06ae1e41 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -0,0 +1,485 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) + + +class OnnxStableDiffusionPipeline(DiffusionPipeline): + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + One or a list of [numpy generator(s)](TODO) to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # get the initial random noise unless the user supplied it + latents_dtype = prompt_embeds.dtype + latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + ): + deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." + deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) + super().__init__( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..293ed7d981b80a30cfad9a4a84478c7209a1cea7 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -0,0 +1,552 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 +def preprocess(image): + warnings.warn( + ( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead" + ), + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt: Union[str, List[str]], + callback_steps: int, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + image = preprocess(image).cpu().numpy() + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + init_latents = 0.18215 * init_latents + + if isinstance(prompt, str): + prompt = [prompt] + if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = len(prompt) // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) + elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." + ) + else: + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # safety_checker does not support batched inputs yet + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb39c4b1c617ea07e71355364f6476f6178e806 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -0,0 +1,560 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +NUM_UNET_INPUT_CHANNELS = 9 +NUM_LATENT_CHANNELS = 4 + + +def prepare_mask_and_masked_image(image, mask, latents_shape): + image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8))) + image = image[None].transpose(0, 3, 1, 2) + image = image.astype(np.float32) / 127.5 - 1.0 + + image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8))) + masked_image = image * (image_mask < 127.5) + + mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"]) + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask, masked_image + + +class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: PIL.Image.Image, + mask_image: PIL.Image.Image, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + num_channels_latents = NUM_LATENT_CHANNELS + latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + latents_dtype = prompt_embeds.dtype + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # prepare mask and masked_image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:]) + mask = mask.astype(latents.dtype) + masked_image = masked_image.astype(latents.dtype) + + masked_image_latents = self.vae_encoder(sample=masked_image)[0] + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt + mask = mask.repeat(batch_size * num_images_per_prompt, 0) + masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0) + + mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + + unet_input_channels = NUM_UNET_INPUT_CHANNELS + if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels: + raise ValueError( + "Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + # concat latents, mask, masked_image_latnets in the channel dimension + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # safety_checker does not support batched inputs yet + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef7a781451c2757e5657aba9c1ff24276890524 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,539 @@ +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + return mask + + +class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to + provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[np.ndarray, PIL.Image.Image] = None, + mask_image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + init_latents_orig = init_latents + + # preprocess mask + if not isinstance(mask_image, np.ndarray): + mask_image = preprocess_mask(mask_image, 8) + mask_image = mask_image.astype(latents_dtype) + mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ? in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ).prev_sample + + latents = latents.numpy() + + init_latents_proper = self.scheduler.add_noise( + torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) + ) + + init_latents_proper = init_latents_proper.numpy() + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # There will throw an error if use safety_checker batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..56681391aeeba7d0146cc4f296e4ead20204c33e --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -0,0 +1,391 @@ +from logging import getLogger +from typing import Any, Callable, List, Optional, Union + +import numpy as np +import PIL +import torch + +from ...schedulers import DDPMScheduler +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import ImagePipelineOutput +from . import StableDiffusionUpscalePipeline + + +logger = getLogger(__name__) + + +NUM_LATENT_CHANNELS = 4 +NUM_UNET_INPUT_CHANNELS = 7 + +ORT_TO_PT_TYPE = { + "float16": torch.float16, + "float32": torch.float32, +} + + +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + return image + + +class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): + def __init__( + self, + vae: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: Any, + unet: OnnxRuntimeModel, + low_res_scheduler: DDPMScheduler, + scheduler: Any, + max_noise_level: int = 350, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + watermarker=None, + max_noise_level=max_noise_level, + ) + + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + noise_level TODO + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs(prompt, image, noise_level, callback_steps) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)] + + # 4. Preprocess image + image = preprocess(image) + image = image.cpu() + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Add noise to image + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = np.concatenate([image] * batch_multiplier * num_images_per_prompt) + noise_level = np.concatenate([noise_level] * image.shape[0]) + + # 6. Prepare latent variables + height, width = image.shape[2:] + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + NUM_LATENT_CHANNELS, + height, + width, + latents_dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS: + raise ValueError( + "Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +" + f" `num_channels_image`: {num_channels_image} " + f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = np.concatenate([latent_model_input, image], axis=1) + + # timestep to tensor + timestep = np.array([t], dtype=timestep_dtype) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=text_embeddings, + class_labels=noise_level.astype(np.int64), + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + image = self.decode_latents(latents.float()) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def decode_latents(self, latents): + latents = 1 / 0.08333 * latents + image = self.vae(latent_sample=latents)[0] + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + return image + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device, + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + # no positional arguments to text_encoder + prompt_embeds = self.text_encoder( + input_ids=text_input_ids.int().to(device), + # attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + # if hasattr(uncond_input, "attention_mask"): + # attention_mask = uncond_input.attention_mask.to(device) + # else: + # attention_mask = None + + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.int().to(device), + # attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + if do_classifier_free_guidance: + seq_len = uncond_embeddings.shape[1] + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) + uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds]) + + return prompt_embeds diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..54927049571cadd73dfc4e2135f48baa34d0011e --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,732 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py new file mode 100644 index 0000000000000000000000000000000000000000..15a5d7eb13624524d944de6d45aa92a208d70ac0 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -0,0 +1,1032 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch.nn import functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import Attention +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionAttendAndExcitePipeline + + >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 + ... ).to("cuda") + + + >>> prompt = "a cat and a frog" + + >>> # use get_indices function to find out indices of the tokens you want to alter + >>> pipe.get_indices(prompt) + {0: '<|startoftext|>', 1: 'a', 2: 'cat', 3: 'and', 4: 'a', 5: 'frog', 6: '<|endoftext|>'} + + >>> token_indices = [2, 5] + >>> seed = 6141 + >>> generator = torch.Generator("cuda").manual_seed(seed) + + >>> images = pipe( + ... prompt=prompt, + ... token_indices=token_indices, + ... guidance_scale=7.5, + ... generator=generator, + ... num_inference_steps=50, + ... max_iter_to_alter=25, + ... ).images + + >>> image = images[0] + >>> image.save(f"../images/{prompt}_{seed}.png") + ``` +""" + + +class AttentionStore: + @staticmethod + def get_empty_store(): + return {"down": [], "mid": [], "up": []} + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= 0 and is_cross: + if attn.shape[1] == np.prod(self.attn_res): + self.step_store[place_in_unet].append(attn) + + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers: + self.cur_att_layer = 0 + self.between_steps() + + def between_steps(self): + self.attention_store = self.step_store + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = self.attention_store + return average_attention + + def aggregate_attention(self, from_where: List[str]) -> torch.Tensor: + """Aggregates the attention across the different layers and heads at the specified resolution.""" + out = [] + attention_maps = self.get_average_attention() + for location in from_where: + for item in attention_maps[location]: + cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1]) + out.append(cross_maps) + out = torch.cat(out, dim=0) + out = out.sum(0) / out.shape[0] + return out + + def reset(self): + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + + def __init__(self, attn_res): + """ + Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion + process + """ + self.num_att_layers = -1 + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + self.curr_step_index = 0 + self.attn_res = attn_res + + +class AttendExciteAttnProcessor: + def __init__(self, attnstore, place_in_unet): + super().__init__() + self.attnstore = attnstore + self.place_in_unet = place_in_unet + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + + # only need to store attention maps during the Attend and Excite process + if attention_probs.requires_grad: + self.attnstore(attention_probs, is_cross, self.place_in_unet) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + indices, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int) + indices_is_list_list_ints = ( + isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int) + ) + + if not indices_is_list_ints and not indices_is_list_list_ints: + raise TypeError("`indices` must be a list of ints or a list of a list of ints") + + if indices_is_list_ints: + indices_batch_size = 1 + elif indices_is_list_list_ints: + indices_batch_size = len(indices) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if indices_batch_size != prompt_batch_size: + raise ValueError( + f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @staticmethod + def _compute_max_attention_per_index( + attention_maps: torch.Tensor, + indices: List[int], + ) -> List[torch.Tensor]: + """Computes the maximum attention value for each of the tokens we wish to alter.""" + attention_for_text = attention_maps[:, :, 1:-1] + attention_for_text *= 100 + attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) + + # Shift indices since we removed the first token + indices = [index - 1 for index in indices] + + # Extract the maximum values + max_indices_list = [] + for i in indices: + image = attention_for_text[:, :, i] + smoothing = GaussianSmoothing().to(attention_maps.device) + input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect") + image = smoothing(input).squeeze(0).squeeze(0) + max_indices_list.append(image.max()) + return max_indices_list + + def _aggregate_and_get_max_attention_per_token( + self, + indices: List[int], + ): + """Aggregates the attention for each token and computes the max activation value for each token to alter.""" + attention_maps = self.attention_store.aggregate_attention( + from_where=("up", "down", "mid"), + ) + max_attention_per_index = self._compute_max_attention_per_index( + attention_maps=attention_maps, + indices=indices, + ) + return max_attention_per_index + + @staticmethod + def _compute_loss(max_attention_per_index: List[torch.Tensor]) -> torch.Tensor: + """Computes the attend-and-excite loss using the maximum attention value for each token.""" + losses = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index] + loss = max(losses) + return loss + + @staticmethod + def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: + """Update the latent according to the computed loss.""" + grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] + latents = latents - step_size * grad_cond + return latents + + def _perform_iterative_refinement_step( + self, + latents: torch.Tensor, + indices: List[int], + loss: torch.Tensor, + threshold: float, + text_embeddings: torch.Tensor, + step_size: float, + t: int, + max_refinement_steps: int = 20, + ): + """ + Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent code + according to our loss objective until the given threshold is reached for all tokens. + """ + iteration = 0 + target_loss = max(0, 1.0 - threshold) + while loss > target_loss: + iteration += 1 + + latents = latents.clone().detach().requires_grad_(True) + self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=indices, + ) + + loss = self._compute_loss(max_attention_per_index) + + if loss != 0: + latents = self._update_latent(latents, loss, step_size) + + logger.info(f"\t Try {iteration}. loss: {loss}") + + if iteration >= max_refinement_steps: + logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ") + break + + # Run one more time but don't compute gradients and update the latents. + # We just need to compute the new loss - the grad update will occur below + latents = latents.clone().detach().requires_grad_(True) + _ = self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=indices, + ) + loss = self._compute_loss(max_attention_per_index) + logger.info(f"\t Finished with loss of: {loss}") + return loss, latents, max_attention_per_index + + def register_attention_control(self): + attn_procs = {} + cross_att_count = 0 + for name in self.unet.attn_processors.keys(): + if name.startswith("mid_block"): + place_in_unet = "mid" + elif name.startswith("up_blocks"): + place_in_unet = "up" + elif name.startswith("down_blocks"): + place_in_unet = "down" + else: + continue + + cross_att_count += 1 + attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet) + + self.unet.set_attn_processor(attn_procs) + self.attention_store.num_att_layers = cross_att_count + + def get_indices(self, prompt: str) -> Dict[str, int]: + """Utility function to list the indices of the tokens you wish to alte""" + ids = self.tokenizer(prompt).input_ids + indices = {i: tok for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))} + return indices + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + token_indices: Union[List[int], List[List[int]]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + max_iter_to_alter: int = 25, + thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8}, + scale_factor: int = 20, + attn_res: Optional[Tuple[int]] = (16, 16), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + token_indices (`List[int]`): + The token indices to alter with attend-and-excite. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + max_iter_to_alter (`int`, *optional*, defaults to `25`): + Number of denoising steps to apply attend-and-excite. The first denoising steps are + where the attend-and-excite is applied. I.e. if `max_iter_to_alter` is 25 and there are a total of `30` + denoising steps, the first 25 denoising steps will apply attend-and-excite and the last 5 will not + apply attend-and-excite. + thresholds (`dict`, *optional*, defaults to `{0: 0.05, 10: 0.5, 20: 0.8}`): + Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in. + scale_factor (`int`, *optional*, default to 20): + Scale factor that controls the step size of each Attend and Excite update. + attn_res (`tuple`, *optional*, default computed from width and height): + The 2D resolution of the semantic attention map. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. :type attention_store: object + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + token_indices, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if attn_res is None: + attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) + self.attention_store = AttentionStore(attn_res) + self.register_attention_control() + + # default config for step size from original repo + scale_range = np.linspace(1.0, 0.5, len(self.scheduler.timesteps)) + step_size = scale_factor * np.sqrt(scale_range) + + text_embeddings = ( + prompt_embeds[batch_size * num_images_per_prompt :] if do_classifier_free_guidance else prompt_embeds + ) + + if isinstance(token_indices[0], int): + token_indices = [token_indices] + + indices = [] + + for ind in token_indices: + indices = indices + [ind] * num_images_per_prompt + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Attend and excite process + with torch.enable_grad(): + latents = latents.clone().detach().requires_grad_(True) + updated_latents = [] + for latent, index, text_embedding in zip(latents, indices, text_embeddings): + # Forward pass of denoising with text conditioning + latent = latent.unsqueeze(0) + text_embedding = text_embedding.unsqueeze(0) + + self.unet( + latent, + t, + encoder_hidden_states=text_embedding, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=index, + ) + + loss = self._compute_loss(max_attention_per_index=max_attention_per_index) + + # If this is an iterative refinement step, verify we have reached the desired threshold for all + if i in thresholds.keys() and loss > 1.0 - thresholds[i]: + loss, latent, max_attention_per_index = self._perform_iterative_refinement_step( + latents=latent, + indices=index, + loss=loss, + threshold=thresholds[i], + text_embeddings=text_embedding, + step_size=step_size[i], + t=t, + ) + + # Perform gradient update + if i < max_iter_to_alter: + if loss != 0: + latent = self._update_latent( + latents=latent, + loss=loss, + step_size=step_size[i], + ) + logger.info(f"Iteration {i} | Loss: {loss:0.4f}") + + updated_latents.append(latent) + + latents = torch.cat(updated_latents, dim=0) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class GaussianSmoothing(torch.nn.Module): + """ + Arguments: + Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input + using a depthwise convolution. + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the + gaussian kernel. dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + # channels=1, kernel_size=kernel_size, sigma=sigma, dim=2 + def __init__( + self, + channels: int = 1, + kernel_size: int = 3, + sigma: float = 0.5, + dim: int = 2, + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * dim + if isinstance(sigma, float): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim)) + + def forward(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c7555e2ebad4c7f6045f3975b61f271a97ec8587 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -0,0 +1,28 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: This file is deprecated and will be removed in a future version. +# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works +from ...utils import deprecate +from ..controlnet.multicontrolnet import MultiControlNetModel # noqa: F401 +from ..controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline # noqa: F401 + + +deprecate( + "stable diffusion controlnet", + "0.22.0", + "Importing `StableDiffusionControlNetPipeline` or `MultiControlNetModel` from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import StableDiffusionControlNetPipeline` instead.", + standard_warn=False, + stacklevel=3, +) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py new file mode 100644 index 0000000000000000000000000000000000000000..cae0f3a347de6248669a8508644c6ce7fe98d441 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -0,0 +1,727 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + depth_estimator: DPTForDepthEstimation, + feature_extractor: DPTFeatureExtractor, + ): + super().__init__() + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + depth_estimator=depth_estimator, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device): + if isinstance(image, PIL.Image.Image): + image = [image] + else: + image = list(image) + + if isinstance(image[0], PIL.Image.Image): + width, height = image[0].size + elif isinstance(image[0], np.ndarray): + width, height = image[0].shape[:-1] + else: + height, width = image[0].shape[-2:] + + if depth_map is None: + pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device=device) + # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16. + # So we use `torch.autocast` here for half precision inference. + context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext() + with context_manger: + depth_map = self.depth_estimator(pixel_values).predicted_depth + else: + depth_map = depth_map.to(device=device, dtype=dtype) + + depth_map = torch.nn.functional.interpolate( + depth_map.unsqueeze(1), + size=(height // self.vae_scale_factor, width // self.vae_scale_factor), + mode="bicubic", + align_corners=False, + ) + + depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0 + depth_map = depth_map.to(dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if depth_map.shape[0] < batch_size: + repeat_by = batch_size // depth_map.shape[0] + depth_map = depth_map.repeat(repeat_by, 1, 1, 1) + + depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map + return depth_map + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + depth_map: Optional[torch.FloatTensor] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can accept image latents as `image` only if `depth_map` is not `None`. + depth_map (`torch.FloatTensor`, *optional*): + depth prediction that will be used as additional conditioning for the image generation process. If not + defined, it will automatically predicts the depth via `self.depth_estimator`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + ```py + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> from diffusers import StableDiffusionDepth2ImgPipeline + + >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-depth", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.to("cuda") + + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> init_image = Image.open(requests.get(url, stream=True).raw) + >>> prompt = "two tigers" + >>> n_propmt = "bad, deformed, ugly, bad anotomy" + >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_propmt, strength=0.7).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare depth mask + depth_mask = self.prepare_depth_map( + image, + depth_map, + batch_size * num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds.dtype, + device, + ) + + # 5. Preprocess image + image = self.image_processor.preprocess(image) + + # 6. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 7. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d5953808f173fa9addf3206d958b2f7ec5c056 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -0,0 +1,1532 @@ +# Copyright 2023 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + BaseOutput, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class DiffEditInversionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.FloatTensor`) + inverted latents tensor + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps, + batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the + diffusion pipeline. + """ + + latents: torch.FloatTensor + images: Union[List[PIL.Image.Image], np.ndarray] + + +EXAMPLE_DOC_STRING = """ + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> mask_prompt = "A bowl of fruits" + >>> prompt = "A bowl of pears" + + >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) + >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents + >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> prompt = "A bowl of fruits" + + >>> inverted_latents = pipe.invert(image=init_image, prompt=prompt).latents + ``` +""" + + +def auto_corr_loss(hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2) + return reg_loss + + +def kl_divergence(hidden_states): + return hidden_states.var() + hidden_states.mean() ** 2 - 1 - torch.log(hidden_states.var() + 1e-7) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def preprocess_mask(mask, batch_size: int = 1): + if not isinstance(mask, torch.Tensor): + # preprocess mask + if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray): + mask = [mask] + + if isinstance(mask, list): + if isinstance(mask[0], PIL.Image.Image): + mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask] + if isinstance(mask[0], np.ndarray): + mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0) + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + # Check mask shape + if batch_size > 1: + if mask.shape[0] == 1: + mask = torch.cat([mask] * batch_size) + elif mask.shape[0] > 1 and mask.shape[0] != batch_size: + raise ValueError( + f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} " + f"inferred by prompt inputs" + ) + + if mask.shape[1] != 1: + raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("`mask_image` should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask + + +class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion using DiffEdit. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + inverse_scheduler (`[DDIMInverseScheduler]`): + A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + inverse_scheduler: DDIMInverseScheduler, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (strength is None) or (strength is not None and (strength < 0 or strength > 1)): + raise ValueError( + f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def check_source_inputs( + self, + source_prompt=None, + source_negative_prompt=None, + source_prompt_embeds=None, + source_negative_prompt_embeds=None, + ): + if source_prompt is not None and source_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}." + " Please make sure to only forward one of the two." + ) + elif source_prompt is None and source_prompt_embeds is None: + raise ValueError( + "Provide either `source_image` or `source_prompt_embeds`. Cannot leave all both of the arguments undefined." + ) + elif source_prompt is not None and ( + not isinstance(source_prompt, str) and not isinstance(source_prompt, list) + ): + raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}") + + if source_negative_prompt is not None and source_negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:" + f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if source_prompt_embeds is not None and source_negative_prompt_embeds is not None: + if source_prompt_embeds.shape != source_negative_prompt_embeds.shape: + raise ValueError( + "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed" + f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !=" + f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def get_inverse_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + # safety for t_start overflow to prevent empty timsteps slice + if t_start == 0: + return self.inverse_scheduler.timesteps, num_inference_steps + timesteps = self.inverse_scheduler.timesteps[:-t_start] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + @torch.no_grad() + def generate_mask( + self, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + target_prompt: Optional[Union[str, List[str]]] = None, + target_negative_prompt: Optional[Union[str, List[str]]] = None, + target_prompt_embeds: Optional[torch.FloatTensor] = None, + target_negative_prompt_embeds: Optional[torch.FloatTensor] = None, + source_prompt: Optional[Union[str, List[str]]] = None, + source_negative_prompt: Optional[Union[str, List[str]]] = None, + source_prompt_embeds: Optional[torch.FloatTensor] = None, + source_negative_prompt_embeds: Optional[torch.FloatTensor] = None, + num_maps_per_mask: Optional[int] = 10, + mask_encode_strength: Optional[float] = 0.5, + mask_thresholding_ratio: Optional[float] = 3.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "np", + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function used to generate a latent mask given a mask prompt, a target prompt, and an image. + + Args: + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be used for computing the mask. + target_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation. If not defined, one has to pass + `prompt_embeds`. instead. + target_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + target_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + target_negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + source_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If + not defined, one has to pass `source_prompt_embeds` or `source_image` instead. + source_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation away from using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If + not defined, one has to pass `source_negative_prompt_embeds` or `source_image` instead. + source_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from + `source_prompt` input argument. + source_negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to easily + tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from + `source_negative_prompt` input argument. + num_maps_per_mask (`int`, *optional*, defaults to 10): + The number of noise maps sampled to generate the semantic mask using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). + mask_encode_strength (`float`, *optional*, defaults to 0.5): + Conceptually, the strength of the noise maps sampled to generate the semantic mask using the method in + [DiffEdit: Diffusion-Based Semantic Image Editing with Mask Guidance]( + https://arxiv.org/pdf/2210.11427.pdf). Must be between 0 and 1. + mask_thresholding_ratio (`float`, *optional*, defaults to 3.0): + The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before + mask binarization. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + `List[PIL.Image.Image]` or `np.array`: `List[PIL.Image.Image]` if `output_type` is `"pil"`, otherwise a + `np.array`. When returning a `List[PIL.Image.Image]`, the list will consist of a batch of single-channel + binary image with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`, otherwise + the `np.array` will have shape `(batch_size, height // self.vae_scale_factor, width // + self.vae_scale_factor)`. + """ + + # 1. Check inputs (Provide dummy argument for callback_steps) + self.check_inputs( + target_prompt, + mask_encode_strength, + 1, + target_negative_prompt, + target_prompt_embeds, + target_negative_prompt_embeds, + ) + + self.check_source_inputs( + source_prompt, + source_negative_prompt, + source_prompt_embeds, + source_negative_prompt_embeds, + ) + + if (num_maps_per_mask is None) or ( + num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0) + ): + raise ValueError( + f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type" + f" {type(num_maps_per_mask)}." + ) + + if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0: + raise ValueError( + f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type" + f" {type(mask_thresholding_ratio)}." + ) + + # 2. Define call parameters + if target_prompt is not None and isinstance(target_prompt, str): + batch_size = 1 + elif target_prompt is not None and isinstance(target_prompt, list): + batch_size = len(target_prompt) + else: + batch_size = target_prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None) + target_prompt_embeds = self._encode_prompt( + target_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + target_negative_prompt, + prompt_embeds=target_prompt_embeds, + negative_prompt_embeds=target_negative_prompt_embeds, + ) + + source_prompt_embeds = self._encode_prompt( + source_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + source_negative_prompt, + prompt_embeds=source_prompt_embeds, + negative_prompt_embeds=source_negative_prompt_embeds, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device) + encode_timestep = timesteps[0] + + # 6. Prepare image latents and add noise with specified strength + image_latents = self.prepare_image_latents( + image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator + ) + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype) + image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep) + + latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep) + + # 7. Predict the noise residual + prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds]) + noise_pred = self.unet( + latent_model_input, + encode_timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if do_classifier_free_guidance: + noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) + noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src) + noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond) + else: + noise_pred_source, noise_pred_target = noise_pred.chunk(2) + + # 8. Compute the mask from the absolute difference of predicted noise residuals + # TODO: Consider smoothing mask guidance map + mask_guidance_map = ( + torch.abs(noise_pred_target - noise_pred_source) + .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:]) + .mean([1, 2]) + ) + clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio + semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude + semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1) + mask_image = semantic_mask_image.cpu().numpy() + + # 9. Convert to Numpy array or PIL. + if output_type == "pil": + mask_image = self.image_processor.numpy_to_pil(mask_image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return mask_image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + num_inference_steps: int = 50, + inpaint_strength: float = 0.8, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + decode_latents: bool = False, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 0, + num_auto_corr_rolls: int = 5, + ): + r""" + Function used to generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch to produce the inverted latents, guided by `prompt`. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how far into the noising process to run latent inversion. Must be between 0 and + 1. When `strength` is 1, the inversion process will be run for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the inversion process, adding more + noise the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + decode_latents (`bool`, *optional*, defaults to `False`): + Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` + will decode all inverted latents for each timestep into a list of generated images. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 0): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or + `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] + if `return_dict` is `True`, otherwise a `tuple`. When returning a tuple, the first element is the inverted + latents tensors ordered by increasing noise, and then second is the corresponding decoded images if + `decode_latents` is `True`, otherwise `None`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare latent variables + num_images_per_prompt = 1 + latents = self.prepare_image_latents( + image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator + ) + + # 5. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 6. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) + + # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + inverted_latents = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) + if num_reg_steps > 0: + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + inverted_latents.append(latents.detach().clone()) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + assert len(inverted_latents) == len(timesteps) + latents = torch.stack(list(reversed(inverted_latents)), 1) + + # 8. Post-processing + image = None + if decode_latents: + image = self.decode_latents(latents.flatten(0, 1)) + + # 9. Convert to PIL. + if decode_latents and output_type == "pil": + image = self.image_processor.numpy_to_pil(image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (latents, image) + + return DiffEditInversionPipelineOutput(latents=latents, images=image) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image_latents: Union[torch.FloatTensor, PIL.Image.Image] = None, + inpaint_strength: Optional[float] = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask the generated image. White pixels in the mask + will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be + converted to a single channel (luminance) before use. If it's a tensor, it should contain one color + channel (L) instead of 3, so the expected shape would be `(B, 1, H, W)`. + image_latents (`PIL.Image.Image` or `torch.FloatTensor`): + Partially noised image latents from the inversion process to be used as inputs for image generation. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image_latents` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if mask_image is None: + raise ValueError( + "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts." + ) + if image_latents is None: + raise ValueError( + "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images." + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess mask + mask_image = preprocess_mask(mask_image, batch_size) + latent_height, latent_width = mask_image.shape[-2:] + mask_image = torch.cat([mask_image] * num_images_per_prompt) + mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) + + # 6. Preprocess image latents + if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents): + image_latents = torch.cat(image_latents).detach() + elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5: + image_latents = image_latents.detach() + else: + image_latents = self.image_processor.preprocess(image_latents).detach() + + latent_shape = (self.vae.config.latent_channels, latent_height, latent_width) + if image_latents.shape[-3:] != latent_shape: + raise ValueError( + f"Each latent image in `image_latents` must have shape {latent_shape}, " + f"but has shape {image_latents.shape[-3:]}" + ) + if image_latents.ndim == 4: + image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape) + if image_latents.shape[:2] != (batch_size, len(timesteps)): + raise ValueError( + f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}" + f" timesteps, but has batch size {image_latents.shape[0]} with latent images from" + f" {image_latents.shape[1]} timesteps." + ) + image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1) + image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + latents = image_latents[0].clone() + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # mask with inverted latents from appropriate timestep - use original image latent for last step + latents = latents * mask_image + image_latents[i] * (1 - mask_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py new file mode 100644 index 0000000000000000000000000000000000000000..ebcb55a7cff90cc06d1c05476013c05526f19f19 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -0,0 +1,394 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + Pipeline to generate variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + # TODO: feature_extractor is required to encode images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image or images to guide the image generation. If you provide a tensor, it needs to comply with the + configuration of + [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) + `CLIPImageProcessor` + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4420f2838675d06b2558c37cd31a133915f423 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -0,0 +1,764 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" + >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" + + +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionImg2ImgPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..b11ebfb6cfc55c888fd1fdfe2cb68151a6567c88 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -0,0 +1,1030 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class StableDiffusionInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + + + It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such + as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). Default + text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) are also compatible with + this pipeline, but might be less performant. + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 9: + logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + height, + width, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..049e3d18f3de867969af00e6e141335576286263 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,738 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + + +def preprocess_image(image, batch_size): + w, h = image.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, batch_size, scale_factor=8): + if not isinstance(mask, torch.FloatTensor): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = np.vstack([mask[None]] * batch_size) + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + else: + valid_mask_channel_sizes = [1, 3] + # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W) + if mask.shape[3] in valid_mask_channel_sizes: + mask = mask.permute(0, 3, 1, 2) + elif mask.shape[1] not in valid_mask_channel_sizes: + raise ValueError( + f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension," + f" but received mask of shape {tuple(mask.shape)}" + ) + # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape + mask = mask.mean(dim=1, keepdim=True) + h, w = mask.shape[-2:] + h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8 + mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) + return mask + + +class StableDiffusionInpaintPipelineLegacy( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + deprecation_message = ( + f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality" + "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533" + "for more information." + ) + deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False) + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator): + image = image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = self.vae.config.scaling_factor * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # add noise to latents using the timesteps + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + add_predicted_noise: Optional[bool] = False, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the + expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to + that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + add_predicted_noise (`bool`, *optional*, defaults to True): + Use predicted noise instead of random noise when constructing noisy versions of the original image in + the reverse diffusion process + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image and mask + if not isinstance(image, torch.FloatTensor): + image = preprocess_image(image, batch_size) + + mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + # encode the init image into latents and scale the latents + latents, init_latents_orig, noise = self.prepare_latents( + image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 7. Prepare mask latent + mask = mask_image.to(device=device, dtype=latents.dtype) + mask = torch.cat([mask] * num_images_per_prompt) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # masking + if add_predicted_noise: + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise_pred_uncond, torch.tensor([t]) + ) + else: + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # use original latents corresponding to unmasked portions of the image + latents = (init_latents_orig * mask) + (latents * (1 - mask)) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py new file mode 100644 index 0000000000000000000000000000000000000000..341ff8daad4299a05b89a6b5d985399c1a7fb429 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -0,0 +1,758 @@ +# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + num_inference_steps: int = 100, + guidance_scale: float = 7.5, + image_guidance_scale: float = 1.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also + accpet image latents as `image`, if passing latents directly, it will not be encoded again. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. This pipeline requires a value of at least `1`. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Image guidance scale is to push the generated image towards the inital image `image`. Image guidance + scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to + generate images that are closely linked to the source image `image`, usually at the expense of lower + image quality. This pipeline requires a value of at least `1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInstructPix2PixPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + + >>> image = download_image(img_url).resize((512, 512)) + + >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( + ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "make the mountains snowy" + >>> image = pipe(prompt=prompt, image=image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Check inputs + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 + # check if scheduler is in sigmas space + scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 2. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + generator, + ) + + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance\ + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False + )[0] + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. So we need to compute the + # predicted_original_sample here if we are using a karras style scheduler. + if scheduler_is_in_sigma_space: + step_index = (self.scheduler.timesteps == t).nonzero()[0].item() + sigma = self.scheduler.sigmas[step_index] + noise_pred = latent_model_input - sigma * noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. But the scheduler.step function + # expects the noise_pred and computes the predicted_original_sample internally. So we + # need to overwrite the noise_pred here such that the value of the computed + # predicted_original_sample is correct. + if scheduler_is_in_sigma_space: + noise_pred = (noise_pred - latents) / (-sigma) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..29a57470a341f4cf1a155af9e3e023091f9e55e8 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -0,0 +1,600 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import torch +from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser +from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...pipelines import DiffusionPipeline +from ...schedulers import LMSDiscreteScheduler +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ModelWrapper: + def __init__(self, model, alphas_cumprod): + self.model = model + self.alphas_cumprod = alphas_cumprod + + def apply_model(self, *args, **kwargs): + if len(args) == 3: + encoder_hidden_states = args[-1] + args = args[:2] + if kwargs.get("cond", None) is not None: + encoder_hidden_states = kwargs.pop("cond") + return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample + + +class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + + + This is an experimental pipeline and is likely to change in the future. + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + logger.info( + f"{self.__class__} is an experimntal pipeline and is likely to change in the future. We recommend to use" + " this pipeline for fast experimentation / iteration if needed, but advice to rely on existing pipelines" + " as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for" + " production settings." + ) + + # get correct sigmas from LMS + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + model = ModelWrapper(unet, scheduler.alphas_cumprod) + if scheduler.config.prediction_type == "v_prediction": + self.k_diffusion_model = CompVisVDenoiser(model) + else: + self.k_diffusion_model = CompVisDenoiser(model) + + def set_scheduler(self, scheduler_type: str): + library = importlib.import_module("k_diffusion") + sampling = getattr(library, "sampling") + self.sampler = getattr(sampling, scheduler_type) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to + `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M + Karras`. + noise_sampler_seed (`int`, *optional*, defaults to `None`): + The random seed to use for the noise sampler. If `None`, a random seed will be generated. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = True + if guidance_scale <= 1.0: + raise ValueError("has to use guidance_scale") + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) + + # 5. Prepare sigmas + if use_karras_sigmas: + sigma_min: float = self.k_diffusion_model.sigmas[0].item() + sigma_max: float = self.k_diffusion_model.sigmas[-1].item() + sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) + sigmas = sigmas.to(device) + else: + sigmas = self.scheduler.sigmas + sigmas = sigmas.to(prompt_embeds.dtype) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents = latents * sigmas[0] + self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) + self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) + + # 7. Define model function + def model_fn(x, t): + latent_model_input = torch.cat([x] * 2) + t = torch.cat([t] * 2) + + noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds) + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + return noise_pred + + # 8. Run k-diffusion solver + sampler_kwargs = {} + + if "noise_sampler" in inspect.signature(self.sampler).parameters: + min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) + sampler_kwargs["noise_sampler"] = noise_sampler + + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5022c854a8f17dfbc09221177763a34f9d8d36 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -0,0 +1,503 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionLatentUpscalePipeline(DiffusionPipeline): + r""" + Pipeline to upscale the resolution of Stable Diffusion output images by a factor of 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`EulerDiscreteScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: EulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") + + def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_length=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_out = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + text_embeddings = text_encoder_out.hidden_states[-1] + text_pooler_out = text_encoder_out.pooler_output + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=True, + return_tensors="pt", + ) + + uncond_encoder_out = self.text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + uncond_embeddings = uncond_encoder_out.hidden_states[-1] + uncond_pooler_out = uncond_encoder_out.pooler_output + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out]) + + return text_embeddings, text_pooler_out + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs(self, prompt, image, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, list) or isinstance(image, torch.Tensor): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] if image.ndim == 4 else 1 + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image upscaling. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be + either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It + will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an + image representation and encoded using this pipeline's `vae` encoder. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + ```py + >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline + >>> import torch + + + >>> pipeline = StableDiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 + ... ) + >>> pipeline.to("cuda") + + >>> model_id = "stabilityai/sd-x2-latent-upscaler" + >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> upscaler.to("cuda") + + >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic" + >>> generator = torch.manual_seed(33) + + >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images + + >>> with torch.no_grad(): + ... image = pipeline.decode_latents(low_res_latents) + >>> image = pipeline.numpy_to_pil(image)[0] + + >>> image.save("../images/a1.png") + + >>> upscaled_image = upscaler( + ... prompt=prompt, + ... image=low_res_latents, + ... num_inference_steps=20, + ... guidance_scale=0, + ... generator=generator, + ... ).images[0] + + >>> upscaled_image.save("../images/a2.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs(prompt, image, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if guidance_scale == 0: + prompt = [""] * batch_size + + # 3. Encode input prompt + text_embeddings, text_pooler_out = self._encode_prompt( + prompt, device, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + image = image.to(dtype=text_embeddings.dtype, device=device) + if image.shape[1] == 3: + # encode image if not in latent-space yet + image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = image[None, :] if image.ndim == 3 else image + image = torch.cat([image] * batch_multiplier) + + # 5. Add noise to image (set to be 0): + # (see below notes from the author): + # "the This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default." + noise_level = torch.tensor([0.0], dtype=torch.float32, device=device) + noise_level = torch.cat([noise_level] * image.shape[0]) + inv_noise_level = (noise_level**2 + 1) ** (-0.5) + + image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None] + image_cond = image_cond.to(text_embeddings.dtype) + + noise_level_embed = torch.cat( + [ + torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device), + torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device), + ], + dim=1, + ) + + timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1) + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height * 2, # 2x upscale + width * 2, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 9. Denoising loop + num_warmup_steps = 0 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + sigma = self.scheduler.sigmas[i] + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1) + # preconditioning parameter based on Karras et al. (2022) (table 1) + timestep = torch.log(sigma) * 0.25 + + noise_pred = self.unet( + scaled_model_input, + timestep, + encoder_hidden_states=text_embeddings, + timestep_cond=timestep_condition, + ).sample + + # in original repo, the output contains a variance channel that's not used + noise_pred = noise_pred[:, :-1] + + # apply preconditioning, based on table 1 in Karras et al. (2022) + inv_sigma = 1 / (sigma**2 + 1) + noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py new file mode 100644 index 0000000000000000000000000000000000000000..95dd207f9d12b48a38b1934e46d42dc6dc03e1d5 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -0,0 +1,675 @@ +# Copyright 2023 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessorLDM3D +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BaseOutput, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d") + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> output = pipe(prompt) + >>> rgb_image, depth_image = output.rgb, output.depth + ``` +""" + + +@dataclass +class LDM3DPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + rgb: Union[List[PIL.Image.Image], np.ndarray] + depth: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +class StableDiffusionLDM3DPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image and 3d generation using LDM3D. LDM3D: Latent Diffusion Model for 3D: + https://arxiv.org/abs/2305.10853 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode rgb and depth images to and from latent + representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded rgb and depth latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + rgb_feature_extractor_input = feature_extractor_input[0] + safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 49, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return ((rgb, depth), has_nsfw_concept) + + return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py new file mode 100644 index 0000000000000000000000000000000000000000..2ecb3f9dbaf73b848552cd4aed386f40b59a1f87 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -0,0 +1,770 @@ +# Copyright 2023 TIME Authors and The HuggingFace Team. All rights reserved." +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...schedulers.scheduling_utils import SchedulerMixin +from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +AUGS_CONST = ["A photo of ", "An image of ", "A picture of "] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionModelEditingPipeline + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt) + + >>> pipe = pipe.to("cuda") + + >>> source_prompt = "A pack of roses" + >>> destination_prompt = "A pack of blue roses" + >>> pipe.edit_model(source_prompt, destination_prompt) + + >>> prompt = "A field of roses" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models". + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + with_to_k ([`bool`]): + Whether to edit the key projection matrices along wiht the value projection matrices. + with_augs ([`list`]): + Textual augmentations to apply while editing the text-to-image model. Set to [] for no augmentations. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + with_to_k: bool = True, + with_augs: list = AUGS_CONST, + ): + super().__init__() + + if isinstance(scheduler, PNDMScheduler): + logger.error("PNDMScheduler for this pipeline is currently not supported.") + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.with_to_k = with_to_k + self.with_augs = with_augs + + # get cross-attention layers + ca_layers = [] + + def append_ca(net_): + if net_.__class__.__name__ == "CrossAttention": + ca_layers.append(net_) + elif hasattr(net_, "children"): + for net__ in net_.children(): + append_ca(net__) + + # recursively find all cross-attention layers in unet + for net in self.unet.named_children(): + if "down" in net[0]: + append_ca(net[1]) + elif "up" in net[0]: + append_ca(net[1]) + elif "mid" in net[0]: + append_ca(net[1]) + + # get projection matrices + self.ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768] + self.projection_matrices = [l.to_v for l in self.ca_clip_layers] + self.og_matrices = [copy.deepcopy(l.to_v) for l in self.ca_clip_layers] + if self.with_to_k: + self.projection_matrices = self.projection_matrices + [l.to_k for l in self.ca_clip_layers] + self.og_matrices = self.og_matrices + [copy.deepcopy(l.to_k) for l in self.ca_clip_layers] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def edit_model( + self, + source_prompt: str, + destination_prompt: str, + lamb: float = 0.1, + restart_params: bool = True, + ): + r""" + Apply model editing via closed-form solution (see Eq. 5 in the TIME paper https://arxiv.org/abs/2303.08084) + + Args: + source_prompt (`str`): + The source prompt containing the concept to be edited. + destination_prompt (`str`): + The destination prompt. Must contain all words from source_prompt with additional ones to specify the + target edit. + lamb (`float`, *optional*, defaults to 0.1): + The lambda parameter specifying the regularization intesity. Smaller values increase the editing power. + restart_params (`bool`, *optional*, defaults to True): + Restart the model parameters to their pre-trained version before editing. This is done to avoid edit + compounding. When it is False, edits accumulate. + """ + + # restart LDM parameters + if restart_params: + num_ca_clip_layers = len(self.ca_clip_layers) + for idx_, l in enumerate(self.ca_clip_layers): + l.to_v = copy.deepcopy(self.og_matrices[idx_]) + self.projection_matrices[idx_] = l.to_v + if self.with_to_k: + l.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_]) + self.projection_matrices[num_ca_clip_layers + idx_] = l.to_k + + # set up sentences + old_texts = [source_prompt] + new_texts = [destination_prompt] + # add augmentations + base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:] + for aug in self.with_augs: + old_texts.append(aug + base) + base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:] + for aug in self.with_augs: + new_texts.append(aug + base) + + # prepare input k* and v* + old_embs, new_embs = [], [] + for old_text, new_text in zip(old_texts, new_texts): + text_input = self.tokenizer( + [old_text, new_text], + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + old_emb, new_emb = text_embeddings + old_embs.append(old_emb) + new_embs.append(new_emb) + + # identify corresponding destinations for each token in old_emb + idxs_replaces = [] + for old_text, new_text in zip(old_texts, new_texts): + tokens_a = self.tokenizer(old_text).input_ids + tokens_b = self.tokenizer(new_text).input_ids + tokens_a = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_a] + tokens_b = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_b] + num_orig_tokens = len(tokens_a) + idxs_replace = [] + j = 0 + for i in range(num_orig_tokens): + curr_token = tokens_a[i] + while tokens_b[j] != curr_token: + j += 1 + idxs_replace.append(j) + j += 1 + while j < 77: + idxs_replace.append(j) + j += 1 + while len(idxs_replace) < 77: + idxs_replace.append(76) + idxs_replaces.append(idxs_replace) + + # prepare batch: for each pair of setences, old context and new values + contexts, valuess = [], [] + for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces): + context = old_emb.detach() + values = [] + with torch.no_grad(): + for layer in self.projection_matrices: + values.append(layer(new_emb[idxs_replace]).detach()) + contexts.append(context) + valuess.append(values) + + # edit the model + for layer_num in range(len(self.projection_matrices)): + # mat1 = \lambda W + \sum{v k^T} + mat1 = lamb * self.projection_matrices[layer_num].weight + + # mat2 = \lambda I + \sum{k k^T} + mat2 = lamb * torch.eye( + self.projection_matrices[layer_num].weight.shape[1], + device=self.projection_matrices[layer_num].weight.device, + ) + + # aggregate sums for mat1, mat2 + for context, values in zip(contexts, valuess): + context_vector = context.reshape(context.shape[0], context.shape[1], 1) + context_vector_T = context.reshape(context.shape[0], 1, context.shape[1]) + value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1) + for_mat1 = (value_vector @ context_vector_T).sum(dim=0) + for_mat2 = (context_vector @ context_vector_T).sum(dim=0) + mat1 += for_mat1 + mat2 += for_mat2 + + # update projection matrix + self.projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2)) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py new file mode 100644 index 0000000000000000000000000000000000000000..37e705d1bc5a822a5ae6242139123f3a14b76a68 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -0,0 +1,739 @@ +# Copyright 2023 MultiDiffusion Authors and The HuggingFace Team. All rights reserved." +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler +from ...utils import logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler + + >>> model_ckpt = "stabilityai/stable-diffusion-2-base" + >>> scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler") + >>> pipe = StableDiffusionPanoramaPipeline.from_pretrained( + ... model_ckpt, scheduler=scheduler, torch_dtype=torch.float16 + ... ) + + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of the dolomites" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image + Generation". + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). + + To generate panorama-like images, be sure to pass the `width` parameter accordingly when using the pipeline. Our + recommendation for the `width` value is 2048. This is the default value of the `width` parameter for this pipeline. + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. The original work + on Multi Diffsion used the [`DDIMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def decode_latents_with_padding(self, latents, padding=8): + # Add padding to latents for circular inference + # padding is the number of latents to add on each side + # it would slightly increase the memory usage, but remove the boundary artifacts + latents = 1 / self.vae.config.scaling_factor * latents + latents_left = latents[..., :padding] + latents_right = latents[..., -padding:] + latents = torch.cat((latents_right, latents, latents_left), axis=-1) + image = self.vae.decode(latents, return_dict=False)[0] + padding_pix = self.vae_scale_factor * padding + image = image[..., padding_pix:-padding_pix] + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False): + # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) + # if panorama's height/width < window_size, num_blocks of height/width should return 1 + panorama_height /= 8 + panorama_width /= 8 + num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1 + if circular_padding: + num_blocks_width = panorama_width // stride if panorama_width > window_size else 1 + else: + num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * stride) + h_end = h_start + window_size + w_start = int((i % num_blocks_width) * stride) + w_end = w_start + window_size + views.append((h_start, h_end, w_start, w_end)) + return views + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 512, + width: Optional[int] = 2048, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + view_batch_size: int = 1, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + circular_padding: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 512: + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 2048): + The width in pixels of the generated image. The width is kept to a high number because the + pipeline is supposed to be used for generating panorama-like images. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + view_batch_size (`int`, *optional*, defaults to 1): + The batch size to denoise splited views. For some GPUs with high performance, higher view batch size + can speedup the generation and increase the VRAM usage. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + circular_padding (`bool`, *optional*, defaults to `False`): + If set to True, circular padding is applied to ensure there are no stitching artifacts. Circular + padding allows the model to seamlessly generate a transition from the rightmost part of the image to + the leftmost part, maintaining consistency in a 360-degree sense. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Define panorama grid and initialize views for synthesis. + # prepare batch grid + views = self.get_views(height, width, circular_padding=circular_padding) + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch) + count = torch.zeros_like(latents) + value = torch.zeros_like(latents) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + # Each denoising step also includes refinement of the latents with respect to the + # views. + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + count.zero_() + value.zero_() + + # generate views + # Here, we iterate through different spatial crops of the latents and denoise them. These + # denoised (latent) crops are then averaged to produce the final latent + # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the + # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 + # Batch views denoise + for j, batch_view in enumerate(views_batch): + vb_size = len(batch_view) + # get the latents corresponding to the current view coordinates + if circular_padding: + latents_for_view = [] + for h_start, h_end, w_start, w_end in batch_view: + if w_end > latents.shape[3]: + # Add circular horizontal padding + latent_view = torch.cat( + ( + latents[:, :, h_start:h_end, w_start:], + latents[:, :, h_start:h_end, : w_end - latents.shape[3]], + ), + axis=-1, + ) + else: + latent_view = latents[:, :, h_start:h_end, w_start:w_end] + latents_for_view.append(latent_view) + latents_for_view = torch.cat(latents_for_view) + else: + latents_for_view = torch.cat( + [ + latents[:, :, h_start:h_end, w_start:w_end] + for h_start, h_end, w_start, w_end in batch_view + ] + ) + + # rematch block's scheduler status + self.scheduler.__dict__.update(views_scheduler_status[j]) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + latents_for_view.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latents_for_view + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # repeat prompt_embeds for batch + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds_input, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_denoised_batch = self.scheduler.step( + noise_pred, t, latents_for_view, **extra_step_kwargs + ).prev_sample + + # save views scheduler status after sample + views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) + + # extract value from batch + for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( + latents_denoised_batch.chunk(vb_size), batch_view + ): + if circular_padding and w_end > latents.shape[3]: + # Case for circular padding + value[:, :, h_start:h_end, w_start:] += latents_view_denoised[ + :, :, h_start:h_end, : latents.shape[3] - w_start + ] + value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[ + :, :, h_start:h_end, latents.shape[3] - w_start : + ] + count[:, :, h_start:h_end, w_start:] += 1 + count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1 + else: + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count[:, :, h_start:h_end, w_start:w_end] += 1 + + # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 + latents = torch.where(count > 0, value / count, value) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + if circular_padding: + image = self.decode_latents_with_padding(latents) + else: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py new file mode 100644 index 0000000000000000000000000000000000000000..073f02e8ee98a519f85b4437a46557f0e0f00c7b --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py @@ -0,0 +1,787 @@ +# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DDPMParallelScheduler + >>> from diffusers import StableDiffusionParadigmsPipeline + + >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler") + + >>> pipe = StableDiffusionParadigmsPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> ngpu, batch_per_device = torch.cuda.device_count(), 5 + >>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)]) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0] + ``` +""" + + +class StableDiffusionParadigmsPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Parallelized version of StableDiffusionPipeline, based on the paper https://arxiv.org/abs/2305.16317 This pipeline + parallelizes the denoising steps to generate a single image faster (more akin to model parallelism). + + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs + self.wrapped_unet = self.unet + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _cumsum(self, input, dim, debug=False): + if debug: + # cumsum_cuda_kernel does not have a deterministic implementation + # so perform cumsum on cpu for debugging purposes + return torch.cumsum(input.cpu().float(), dim=dim).to(input.device) + else: + return torch.cumsum(input, dim=dim) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + parallel: int = 10, + tolerance: float = 0.1, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + debug: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + parallel (`int`, *optional*, defaults to 10): + The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but + requires higher memory usage and also can require more total FLOPs. + tolerance (`float`, *optional*, defaults to 0.1): + The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower + tolerance usually leads to less/no degradation. Higher tolerance is faster but can risk degradation of + sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + debug (`bool`, *optional*, defaults to `False`): + Whether or not to run in debug mode. In debug mode, torch.cumsum is evaluated using the CPU. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + extra_step_kwargs.pop("generator", None) + + # # 7. Denoising loop + scheduler = self.scheduler + parallel = min(parallel, len(scheduler.timesteps)) + + begin_idx = 0 + end_idx = parallel + latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1)) + + # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep. + # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop. + noise_array = torch.zeros_like(latents_time_evolution_buffer) + for j in range(len(scheduler.timesteps)): + base_noise = randn_tensor( + shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype + ) + noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise + noise_array[j] = noise.clone() + + # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance + # outside of the denoising loop to avoid recomputing it at every step. + # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step. + inverse_variance_norm = 1.0 / torch.tensor( + [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0] + ).to(noise_array.device) + latent_dim = noise_array[0, 0].numel() + inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim + + scaled_tolerance = tolerance**2 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + steps = 0 + while begin_idx < len(scheduler.timesteps): + # these have shape (parallel_dim, 2*batch_size, ...) + # parallel_len is at most parallel, but could be less if we are at the end of the timesteps + # we are processing batch window of timesteps spanning [begin_idx, end_idx) + parallel_len = end_idx - begin_idx + + block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len) + block_latents = latents_time_evolution_buffer[begin_idx:end_idx] + block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt) + t_vec = block_t + if do_classifier_free_guidance: + t_vec = t_vec.repeat(1, 2) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec) + + # if parallel_len is small, no need to use multiple GPUs + net = self.wrapped_unet if parallel_len > 3 else self.unet + # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...] + model_output = net( + latent_model_input.flatten(0, 1), + t_vec.flatten(0, 1), + encoder_hidden_states=block_prompt_embeds.flatten(0, 1), + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + per_latent_shape = model_output.shape[1:] + if do_classifier_free_guidance: + model_output = model_output.reshape( + parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape + ) + noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1] + model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + model_output = model_output.reshape( + parallel_len * batch_size * num_images_per_prompt, *per_latent_shape + ) + + block_latents_denoise = scheduler.batch_step_no_noise( + model_output=model_output, + timesteps=block_t.flatten(0, 1), + sample=block_latents.flatten(0, 1), + **extra_step_kwargs, + ).reshape(block_latents.shape) + + # back to shape (parallel_dim, batch_size, ...) + # now we want to add the pre-sampled noise + # parallel sampling algorithm requires computing the cumulative drift from the beginning + # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises. + delta = block_latents_denoise - block_latents + cumulative_delta = self._cumsum(delta, dim=0, debug=debug) + cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug) + + # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise + if scheduler._is_ode_scheduler: + cumulative_noise = 0 + + block_latents_new = ( + latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise + ) + cur_error = torch.linalg.norm( + (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape( + parallel_len, batch_size * num_images_per_prompt, -1 + ), + dim=-1, + ).pow(2) + error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1] + + # find the first index of the vector error_ratio that is greater than error tolerance + # we can shift the window for the next iteration up to this index + error_ratio = torch.nn.functional.pad( + error_ratio, (0, 0, 0, 1), value=1e9 + ) # handle the case when everything is below ratio, by padding the end of parallel_len dimension + any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int() + ind = torch.argmax(any_error_at_time).item() + + # compute the new begin and end idxs for the window + new_begin_idx = begin_idx + min(1 + ind, parallel) + new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps)) + + # store the computed latents for the current window in the global buffer + latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new + # initialize the new sliding window latents with the end of the current window, + # should be better than random initialization + latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][ + None, + ] + + steps += 1 + + progress_bar.update(new_begin_idx - begin_idx) + if callback is not None and steps % callback_steps == 0: + callback(begin_idx, block_t[begin_idx], latents_time_evolution_buffer[begin_idx]) + + begin_idx = new_begin_idx + end_idx = new_end_idx + + latents = latents_time_evolution_buffer[-1] + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..960c4369e45ab8792e90c8aaa253f52ac65cef5c --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -0,0 +1,1259 @@ +# Copyright 2023 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import ( + BlipForConditionalGeneration, + BlipProcessor, + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import Attention +from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler +from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler +from ...utils import ( + PIL_INTERPOLATION, + BaseOutput, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.FloatTensor`) + inverted latents tensor + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + latents: torch.FloatTensor + images: Union[List[PIL.Image.Image], np.ndarray] + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + + >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline + + + >>> def download(embedding_url, local_filepath): + ... r = requests.get(embedding_url) + ... with open(local_filepath, "wb") as f: + ... f.write(r.content) + + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.to("cuda") + + >>> prompt = "a high resolution painting of a cat in the style of van gough" + >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" + >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" + + >>> for url in [source_emb_url, target_emb_url]: + ... download(url, url.split("/")[-1]) + + >>> src_embeds = torch.load(source_emb_url.split("/")[-1]) + >>> target_embeds = torch.load(target_emb_url.split("/")[-1]) + >>> images = pipeline( + ... prompt, + ... source_embeds=src_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... ).images + + >>> images[0].save("edited_image_dog.png") + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import BlipForConditionalGeneration, BlipProcessor + >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline + + >>> import requests + >>> from PIL import Image + + >>> captioner_id = "Salesforce/blip-image-captioning-base" + >>> processor = BlipProcessor.from_pretrained(captioner_id) + >>> model = BlipForConditionalGeneration.from_pretrained( + ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ... ) + + >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + ... sd_model_ckpt, + ... caption_generator=model, + ... caption_processor=processor, + ... torch_dtype=torch.float16, + ... safety_checker=None, + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" + + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) + >>> # generate caption + >>> caption = pipeline.generate_caption(raw_image) + + >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" + >>> inv_latents = pipeline.invert(caption, image=raw_image).latents + >>> # we need to generate source and target embeds + + >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] + + >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] + + >>> source_embeds = pipeline.get_embeds(source_prompts) + >>> target_embeds = pipeline.get_embeds(target_prompts) + >>> # the latents can then be used to edit a real image + >>> # when using Stable Diffusion 2 or other models that use v-prediction + >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion + + >>> image = pipeline( + ... caption, + ... source_embeds=source_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... generator=generator, + ... latents=inv_latents, + ... negative_prompt=caption, + ... ).images[0] + >>> image.save("edited_image.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def prepare_unet(unet: UNet2DConditionModel): + """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations.""" + pix2pix_zero_attn_procs = {} + for name in unet.attn_processors.keys(): + module_name = name.replace(".processor", "") + module = unet.get_submodule(module_name) + if "attn2" in name: + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True) + module.requires_grad_(True) + else: + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False) + module.requires_grad_(False) + + unet.set_attn_processor(pix2pix_zero_attn_procs) + return unet + + +class Pix2PixZeroL2Loss: + def __init__(self): + self.loss = 0.0 + + def compute_loss(self, predictions, targets): + self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0) + + +class Pix2PixZeroAttnProcessor: + """An attention processor class to store the attention weights. + In Pix2Pix Zero, it happens during computations in the cross-attention blocks.""" + + def __init__(self, is_pix2pix_zero=False): + self.is_pix2pix_zero = is_pix2pix_zero + if self.is_pix2pix_zero: + self.reference_cross_attn_map = {} + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + timestep=None, + loss=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + if self.is_pix2pix_zero and timestep is not None: + # new bookkeeping to save the attention weights. + if loss is None: + self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu() + # compute loss + elif loss is not None: + prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item()) + loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device)) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): + r""" + Pipeline for pixel-levl image editing using Pix2Pix Zero. Based on Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + requires_safety_checker (bool): + Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the + pipeline publicly. + """ + _optional_components = [ + "safety_checker", + "feature_extractor", + "caption_generator", + "caption_processor", + "inverse_scheduler", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], + feature_extractor: CLIPImageProcessor, + safety_checker: StableDiffusionSafetyChecker, + inverse_scheduler: DDIMInverseScheduler, + caption_generator: BlipForConditionalGeneration, + caption_processor: BlipProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + caption_processor=caption_processor, + caption_generator=caption_generator, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.vae, self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + source_embeds, + target_embeds, + callback_steps, + prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if source_embeds is None and target_embeds is None: + raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def generate_caption(self, images): + """Generates caption for a given image.""" + text = "a photography of" + + prev_device = self.caption_generator.device + + device = self._execution_device + inputs = self.caption_processor(images, text, return_tensors="pt").to( + device=device, dtype=self.caption_generator.dtype + ) + self.caption_generator.to(device) + outputs = self.caption_generator.generate(**inputs, max_new_tokens=128) + + # offload caption generator + self.caption_generator.to(prev_device) + + caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] + return caption + + def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor): + """Constructs the edit direction to steer the image generation process semantically.""" + return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) + + @torch.no_grad() + def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor: + num_prompts = len(prompt) + embeds = [] + for i in range(0, num_prompts, batch_size): + prompt_slice = prompt[i : i + batch_size] + + input_ids = self.tokenizer( + prompt_slice, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids + + input_ids = input_ids.to(self.text_encoder.device) + embeds.append(self.text_encoder(input_ids)[0]) + + return torch.cat(embeds, dim=0).mean(0)[None] + + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + def auto_corr_loss(self, hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + return reg_loss + + def kl_divergence(self, hidden_states): + mean = hidden_states.mean() + var = hidden_states.var() + return var + mean**2 - 1 - torch.log(var + 1e-7) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + source_embeds: torch.Tensor = None, + target_embeds: torch.Tensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + cross_attention_guidance_amount: float = 0.1, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + source_embeds (`torch.Tensor`): + Source concept embeddings. Generation of the embeddings as per the [original + paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. + target_embeds (`torch.Tensor`): + Target concept embeddings. Generation of the embeddings as per the [original + paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Define the spatial resolutions. + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + source_embeds, + target_embeds, + callback_steps, + prompt_embeds, + ) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Generate the inverted noise from the input image or any other image + # generated from the input prompt. + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents_init = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Compute the edit directions. + edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) + + # 9. Edit the prompt embeddings as per the edit directions discovered. + prompt_embeds_edit = prompt_embeds.clone() + prompt_embeds_edit[1:2] += edit_direction + + # 10. Second denoising loop to generate the edited image. + latents = latents_init + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # we want to learn the latent such that it steers the generation + # process towards the edited direction, so make the make initial + # noise learnable + x_in = latent_model_input.detach().clone() + x_in.requires_grad = True + + # optimizer + opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) + + with torch.enable_grad(): + # initialize loss + loss = Pix2PixZeroL2Loss() + + # predict the noise residual + noise_pred = self.unet( + x_in, + t, + encoder_hidden_states=prompt_embeds_edit.detach(), + cross_attention_kwargs={"timestep": t, "loss": loss}, + ).sample + + loss.loss.backward(retain_graph=False) + opt.step() + + # recompute the noise + noise_pred = self.unet( + x_in.detach(), + t, + encoder_hidden_states=prompt_embeds_edit, + cross_attention_kwargs={"timestep": None}, + ).sample + + latents = x_in.detach().chunk(2)[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: Optional[str] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + num_inference_steps: int = 50, + guidance_scale: float = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + cross_attention_guidance_amount: float = 0.1, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 5, + num_auto_corr_rolls: int = 5, + ): + r""" + Function used to generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be used for conditioning. Can also accpet + image latents as `image`, if passing latents directly, it will not be encoded again. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 5): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or + `tuple`: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted + latents tensor and then second is the corresponding decoded image. + """ + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare latent variables + latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) + + # 5. Encode input prompt + num_images_per_prompt = 1 + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + ) + + # 4. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.inverse_scheduler.timesteps + + # 6. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = self.auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = self.kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + inverted_latents = latents.detach().clone() + + # 8. Post-processing + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (inverted_latents, image) + + return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py new file mode 100644 index 0000000000000000000000000000000000000000..9c583de9ca9c03713262c0e103e7f32eca9eb4f6 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -0,0 +1,767 @@ +# Copyright 2023 Susung Hong and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionSAGPipeline + + >>> pipe = StableDiffusionSAGPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, sag_scale=0.75).images[0] + ``` +""" + + +# processes and stores attention probabilities +class CrossAttnStoreProcessor: + def __init__(self): + self.attention_probs = None + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + self.attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(self.attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +# Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input +class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + sag_scale: float = 0.75, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + sag_scale (`float`, *optional*, defaults to 0.75): + SAG scale as defined in [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance] + (https://arxiv.org/abs/2210.00939). `sag_scale` is defined as `s_s` of equation (24) of SAG paper: + https://arxiv.org/pdf/2210.00939.pdf. Typically chosen between [0, 1.0] for better quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # and `sag_scale` is` `s` of equation (16) + # of the self-attentnion guidance paper: https://arxiv.org/pdf/2210.00939.pdf + # `sag_scale = 0` means no self-attention guidance + do_self_attention_guidance = sag_scale > 0.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + store_processor = CrossAttnStoreProcessor() + self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + map_size = None + + def get_map_size(module, input, output): + nonlocal map_size + map_size = output[0].shape[-2:] + + with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # perform self-attention guidance with the stored self-attentnion map + if do_self_attention_guidance: + # classifier-free guidance produces two chunks of attention map + # and we only use unconditional one according to equation (25) + # in https://arxiv.org/pdf/2210.00939.pdf + if do_classifier_free_guidance: + # DDIM-like prediction of x0 + pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) + # get the stored attention maps + uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) + # self-attention-based degrading of latents + degraded_latents = self.sag_masking( + pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t) + ) + uncond_emb, _ = prompt_embeds.chunk(2) + # forward and give guidance + degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample + noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) + else: + # DDIM-like prediction of x0 + pred_x0 = self.pred_x0(latents, noise_pred, t) + # get the stored attention maps + cond_attn = store_processor.attention_probs + # self-attention-based degrading of latents + degraded_latents = self.sag_masking( + pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t) + ) + # forward and give guidance + degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample + noise_pred += sag_scale * (noise_pred - degraded_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def sag_masking(self, original_latents, attn_map, map_size, t, eps): + # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf + bh, hw1, hw2 = attn_map.shape + b, latent_channel, latent_h, latent_w = original_latents.shape + h = self.unet.config.attention_head_dim + if isinstance(h, list): + h = h[-1] + + # Produce attention mask + attn_map = attn_map.reshape(b, h, hw1, hw2) + attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 + attn_mask = ( + attn_mask.reshape(b, map_size[0], map_size[1]) + .unsqueeze(1) + .repeat(1, latent_channel, 1, 1) + .type(attn_map.dtype) + ) + attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) + + # Blur according to the self-attention mask + degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) + degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) + + # Noise it again to match the noise level + degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) + + return degraded_latents + + # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step + # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.) + def pred_x0(self, sample, model_output, timestep): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + if self.scheduler.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.scheduler.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.scheduler.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_original_sample + + def pred_epsilon(self, sample, model_output, timestep): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + if self.scheduler.config.prediction_type == "epsilon": + pred_eps = model_output + elif self.scheduler.config.prediction_type == "sample": + pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) + elif self.scheduler.config.prediction_type == "v_prediction": + pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_eps + + +# Gaussian blur +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + + return img diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..61b1419d5ced4e1c90d6ec2dcecadafce38b2d9e --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -0,0 +1,762 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers +from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-guided image super-resolution using Stable Diffusion 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + low_res_scheduler ([`SchedulerMixin`]): + A scheduler used to add initial noise to the low res conditioning image. It must be an instance of + [`DDPMScheduler`]. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + _optional_components = ["watermarker", "safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + low_res_scheduler: DDPMScheduler, + scheduler: KarrasDiffusionSchedulers, + safety_checker: Optional[Any] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + watermarker: Optional[Any] = None, + max_noise_level: int = 350, + ): + super().__init__() + + if hasattr( + vae, "config" + ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate + is_vae_scaling_factor_set_to_0_08333 = ( + hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 + ) + if not is_vae_scaling_factor_set_to_0_08333: + deprecation_message = ( + "The configuration file of the vae does not contain `scaling_factor` or it is set to" + f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned" + " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to" + " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging" + " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file" + ) + deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) + vae.register_to_config(scaling_factor=0.08333) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + safety_checker=safety_checker, + watermarker=watermarker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") + self.register_to_config(max_noise_level=max_noise_level) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, np.ndarray) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor or numpy array + if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + ```py + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + >>> from diffusers import StableDiffusionUpscalePipeline + >>> import torch + + >>> # load model and scheduler + >>> model_id = "stabilityai/stable-diffusion-x4-upscaler" + >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained( + ... model_id, revision="fp16", torch_dtype=torch.float16 + ... ) + >>> pipeline = pipeline.to("cuda") + + >>> # let's download an image + >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" + >>> response = requests.get(url) + >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> low_res_img = low_res_img.resize((128, 128)) + >>> prompt = "a white cat" + + >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] + >>> upscaled_image.save("upsampled_cat.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Add noise to image + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = torch.cat([image] * batch_multiplier * num_images_per_prompt) + noise_level = torch.cat([noise_level] * image.shape[0]) + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=noise_level, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 11. Apply watermark + if output_type == "pil" and self.watermarker is not None: + image = self.watermarker.apply_watermark(image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py new file mode 100644 index 0000000000000000000000000000000000000000..7c89bfedbd5903409d33c19cab47b69cd6dcc926 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -0,0 +1,914 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel +from ...models.embeddings import get_timestep_embedding +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableUnCLIPPipeline + + >>> pipe = StableUnCLIPPipeline.from_pretrained( + ... "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 + ... ) # TODO update model path + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> images = pipe(prompt).images + >>> images[0].save("astronaut_horse.png") + ``` +""" + + +class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + """ + Pipeline for text-to-image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior_tokenizer ([`CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + prior_scheduler ([`KarrasDiffusionSchedulers`]): + Scheduler used in the prior denoising process. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by `noise_level` in `StableUnCLIPPipeline.__call__`. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`KarrasDiffusionSchedulers`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + + _exclude_from_cpu_offload = ["prior", "image_normalizer"] + + # prior components + prior_tokenizer: CLIPTokenizer + prior_text_encoder: CLIPTextModelWithProjection + prior: PriorTransformer + prior_scheduler: KarrasDiffusionSchedulers + + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: KarrasDiffusionSchedulers + + # regular denoising components + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: KarrasDiffusionSchedulers + + vae: AutoencoderKL + + def __init__( + self, + # prior components + prior_tokenizer: CLIPTokenizer, + prior_text_encoder: CLIPTextModelWithProjection, + prior: PriorTransformer, + prior_scheduler: KarrasDiffusionSchedulers, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: KarrasDiffusionSchedulers, + # regular denoising components + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # vae + vae: AutoencoderKL, + ): + super().__init__() + + self.register_modules( + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_encoder, + prior=prior, + prior_scheduler=prior_scheduler, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder + def _encode_prior_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.prior_tokenizer( + prompt, + padding="max_length", + max_length=self.prior_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.prior_tokenizer.batch_decode( + untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length] + + prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device)) + + prompt_embeds = prior_text_encoder_output.text_embeds + prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.prior_tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.prior_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder( + uncond_input.input_ids.to(device) + ) + + negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds + uncond_prior_text_encoder_hidden_states = ( + negative_prompt_embeds_prior_text_encoder_output.last_hidden_state + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_prior_text_encoder_hidden_states.shape[1] + uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat( + 1, num_images_per_prompt, 1 + ) + uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prior_text_encoder_hidden_states = torch.cat( + [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states] + ) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, prior_text_encoder_hidden_states, text_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler + def prepare_prior_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the prior_scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + noise_level, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." + ) + + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways + 1. A noise schedule is applied directly to the embeddings + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. + """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + self.image_normalizer.to(image_embeds.device) + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + # regular denoising process args + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + # prior args + prior_num_inference_steps: int = 25, + prior_guidance_scale: float = 4.0, + prior_latents: Optional[torch.FloatTensor] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps in the prior denoising process. More denoising steps usually lead to a + higher quality image at the expense of slower inference. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale for the prior denoising process as defined in [Classifier-Free Diffusion + Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2. of + [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + prior_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + embedding generation in the prior denoising process. Can be used to tweak the same generation with + different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied + random `generator`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + prior_do_classifier_free_guidance = prior_guidance_scale > 1.0 + + # 3. Encode input prompt + prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=prior_do_classifier_free_guidance, + ) + + # 4. Prepare prior timesteps + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + # 5. Prepare prior latent variables + embedding_dim = self.prior.config.embedding_dim + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prior_prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta) + + # 7. Prior denoising loop + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents + latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t) + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prior_prompt_embeds, + encoder_hidden_states=prior_text_encoder_hidden_states, + attention_mask=prior_text_mask, + ).predicted_image_embedding + + if prior_do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + **prior_extra_step_kwargs, + return_dict=False, + )[0] + + if callback is not None and i % callback_steps == 0: + callback(i, t, prior_latents) + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeds = prior_latents + + # done prior + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 8. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 9. Prepare image embeddings + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + # 10. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 11. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + latents = self.prepare_latents( + shape=shape, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + scheduler=self.scheduler, + ) + + # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 13. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..003c82ff4f8aedaa9874fb18d9abef858195e4b9 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -0,0 +1,810 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.utils.import_utils import is_accelerate_available + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.embeddings import get_timestep_embedding +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableUnCLIPImg2ImgPipeline + + >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16 + ... ) # TODO update model path + >>> pipe = pipe.to("cuda") + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(prompt, init_image).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" + + +class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + """ + Pipeline for text-guided image to image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + feature_extractor ([`CLIPImageProcessor`]): + Feature extractor for image pre-processing before being encoded. + image_encoder ([`CLIPVisionModelWithProjection`]): + CLIP vision model for encoding images. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by `noise_level` in `StableUnCLIPPipeline.__call__`. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`KarrasDiffusionSchedulers`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + + _exclude_from_cpu_offload = ["image_normalizer"] + + # image encoding components + feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: KarrasDiffusionSchedulers + + # regular denoising components + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: KarrasDiffusionSchedulers + + vae: AutoencoderKL + + def __init__( + self, + # image encoding components + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: KarrasDiffusionSchedulers, + # regular denoising components + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # vae + vae: AutoencoderKL, + ): + super().__init__() + + self.register_modules( + feature_extractor=feature_extractor, + image_encoder=image_encoder, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def _encode_image( + self, + image, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + noise_level, + generator, + image_embeds, + ): + dtype = next(self.image_encoder.parameters()).dtype + + if isinstance(image, PIL.Image.Image): + # the image embedding should repeated so it matches the total batch size of the prompt + repeat_by = batch_size + else: + # assume the image input is already properly batched and just needs to be repeated so + # it matches the num_images_per_prompt. + # + # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched + # `image_embeds`. If those happen to be common use cases, let's think harder about + # what the expected dimensions of inputs should be and how we handle the encoding. + repeat_by = num_images_per_prompt + + if image_embeds is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + image_embeds = image_embeds.unsqueeze(1) + bs_embed, seq_len, _ = image_embeds.shape + image_embeds = image_embeds.repeat(1, repeat_by, 1) + image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1) + image_embeds = image_embeds.squeeze(1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + noise_level, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." + ) + + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." + ) + + if image is not None and image_embeds is not None: + raise ValueError( + "Provide either `image` or `image_embeds`. Please make sure to define only one of the two." + ) + + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + + if image is not None: + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways + 1. A noise schedule is applied directly to the embeddings + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. + """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + self.image_normalizer.to(image_embeds.device) + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 10, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + image_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be + used or prompt is initialized to `""`. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which + the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the + latents in the denoising process such as in the standard stable diffusion text guided image variation + process. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in + the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as + `latents`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if prompt is None and prompt_embeds is None: + prompt = len(image) * [""] if isinstance(image, list) else "" + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Encoder input image + noise_level = torch.tensor([noise_level], device=device) + image_embeds = self._encode_image( + image=image, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + noise_level=noise_level, + generator=generator, + image_embeds=image_embeds, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/stable_diffusion/safety_checker.py b/diffusers/pipelines/stable_diffusion/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..38c7b22d08d43ade5fe7979f5514ec973109fd82 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -0,0 +1,125 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class StableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if torch.is_tensor(images) or torch.is_tensor(images[0]): + images[idx] = torch.zeros_like(images[idx]) # black image + else: + images[idx] = np.zeros(images[idx].shape) # black image + + if any(has_nsfw_concepts): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/diffusers/pipelines/stable_diffusion/safety_checker_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8c3167954016b3b89f16caf8348661cd3a27ef --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -0,0 +1,112 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPConfig, FlaxPreTrainedModel +from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule + + +def jax_cosine_distance(emb_1, emb_2, eps=1e-12): + norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T + norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T + return jnp.matmul(norm_emb_1, norm_emb_2.T) + + +class FlaxStableDiffusionSafetyCheckerModule(nn.Module): + config: CLIPConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) + self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) + + self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) + self.special_care_embeds = self.param( + "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) + ) + + self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) + self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) + + def __call__(self, clip_input): + pooled_output = self.vision_model(clip_input)[1] + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign image inputs + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment + special_scores = jnp.round(special_scores, 3) + is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) + # Use a lower threshold if an image has any special care concept + special_adjustment = is_special_care * 0.01 + + concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment + concept_scores = jnp.round(concept_scores, 3) + has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) + + return has_nsfw_concepts + + +class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + module_class = FlaxStableDiffusionSafetyCheckerModule + + def __init__( + self, + config: CLIPConfig, + input_shape: Optional[Tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if input_shape is None: + input_shape = (1, 224, 224, 3) + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensor + clip_input = jax.random.normal(rng, input_shape) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, clip_input)["params"] + + return random_params + + def __call__( + self, + clip_input, + params: dict = None, + ): + clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) + + return self.module.apply( + {"params": params or self.params}, + jnp.array(clip_input, dtype=jnp.float32), + rngs={}, + ) diff --git a/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7362df7e80e72719133f1804600a618fe161f668 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py @@ -0,0 +1,57 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): + """ + This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. + + It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image + embeddings. + """ + + @register_to_config + def __init__( + self, + embedding_dim: int = 768, + ): + super().__init__() + + self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.std = nn.Parameter(torch.ones(1, embedding_dim)) + + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + ): + self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype)) + self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype)) + return self + + def scale(self, embeds): + embeds = (embeds - self.mean) * 1.0 / self.std + return embeds + + def unscale(self, embeds): + embeds = (embeds * self.std) + self.mean + return embeds diff --git a/diffusers/pipelines/stable_diffusion_safe/__init__.py b/diffusers/pipelines/stable_diffusion_safe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5aecfeac112e53b2fc49278c1acaa95a6c0c7257 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_safe/__init__.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +class SafetyConfig(object): + WEAK = { + "sld_warmup_steps": 15, + "sld_guidance_scale": 20, + "sld_threshold": 0.0, + "sld_momentum_scale": 0.0, + "sld_mom_beta": 0.0, + } + MEDIUM = { + "sld_warmup_steps": 10, + "sld_guidance_scale": 1000, + "sld_threshold": 0.01, + "sld_momentum_scale": 0.3, + "sld_mom_beta": 0.4, + } + STRONG = { + "sld_warmup_steps": 7, + "sld_guidance_scale": 2000, + "sld_threshold": 0.025, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + MAX = { + "sld_warmup_steps": 0, + "sld_guidance_scale": 5000, + "sld_threshold": 1.0, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + + +@dataclass +class StableDiffusionSafePipelineOutput(BaseOutput): + """ + Output class for Safe Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work" + (nsfw) content, or `None` if no safety check was performed or no images were flagged. + applied_safety_concept (`str`) + The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + applied_safety_concept: Optional[str] + + +if is_transformers_available() and is_torch_available(): + from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe + from .safety_checker import SafeStableDiffusionSafetyChecker diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56217772ec01126c22f6242741e3e78242cd7a94 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ebc2d91b192ff22a75ebcf1933075d07f939c17 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..741a61b20498dd1ec9aa6c66413a64cf6ff4bbb9 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c56960c474dd51e5ed783345f629fb28bdc79d Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa82e4bbe818c9c78f8df967aa1c0ab852cbfc53 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8150b284dd66d8c9982d83735c3ef37d092df515 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py new file mode 100644 index 0000000000000000000000000000000000000000..f172575bc6c737a35a6eaed79e4b8915d2b735e6 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -0,0 +1,705 @@ +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionSafePipelineOutput +from .safety_checker import SafeStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionPipelineSafe(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Safe Latent Diffusion. + + The implementation is based on the [`StableDiffusionPipeline`] + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: SafeStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + safety_concept: Optional[str] = ( + "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," + " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" + " abuse, brutality, cruelty" + ) + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self._safety_text_concept = safety_concept + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @property + def safety_concept(self): + r""" + Getter method for the safety concept used with SLD + + Returns: + `str`: The text describing the safety concept + """ + return self._safety_text_concept + + @safety_concept.setter + def safety_concept(self, concept): + r""" + Setter method for the safety concept used with SLD + + Args: + concept (`str`): + The text of the new safety concept + """ + self._safety_text_concept = concept + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + enable_safety_guidance, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # Encode the safety concept text + if enable_safety_guidance: + safety_concept_input = self.tokenizer( + [self._safety_text_concept], + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] + + # duplicate safety embeddings for each generation per prompt, using mps friendly method + seq_len = safety_embeddings.shape[1] + safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) + safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance + sld, we need to do three forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing three forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) + + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype, enable_safety_guidance): + if self.safety_checker is not None: + images = image.copy() + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + flagged_images = np.zeros((2, *image.shape[1:])) + if any(has_nsfw_concept): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead." + f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}" + ) + for idx, has_nsfw_concept in enumerate(has_nsfw_concept): + if has_nsfw_concept: + flagged_images[idx] = images[idx] + image[idx] = np.zeros(image[idx].shape) # black image + else: + has_nsfw_concept = None + flagged_images = None + return image, has_nsfw_concept, flagged_images + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def perform_safety_guidance( + self, + enable_safety_guidance, + safety_momentum, + noise_guidance, + noise_pred_out, + i, + sld_guidance_scale, + sld_warmup_steps, + sld_threshold, + sld_momentum_scale, + sld_mom_beta, + ): + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale + ) + + # Equation 4 + noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + return noise_guidance, safety_momentum + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + sld_guidance_scale: Optional[float] = 1000, + sld_warmup_steps: Optional[int] = 10, + sld_threshold: Optional[float] = 0.01, + sld_momentum_scale: Optional[float] = 0.3, + sld_mom_beta: Optional[float] = 0.4, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + sld_guidance_scale (`float`, *optional*, defaults to 1000): + Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). + `sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be + disabled. + sld_warmup_steps (`int`, *optional*, defaults to 10): + Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than + `sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + sld_threshold (`float`, *optional*, defaults to 0.01): + Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold` + is defined as `lamda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). + sld_momentum_scale (`float`, *optional*, defaults to 0.3): + Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0 + momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + sld_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous + momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance + if not enable_safety_guidance: + warnings.warn("Safety checker disabled!") + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + safety_momentum = None + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (3 if enable_safety_guidance else 2)) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + + # default classifier free guidance + noise_guidance = noise_pred_text - noise_pred_uncond + + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp( + torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 + ) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, + torch.zeros_like(scale), + scale, + ) + + # Equation 4 + noise_guidance_safety = torch.mul( + (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale + ) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + + noise_pred = noise_pred_uncond + guidance_scale * noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept, flagged_images = self.run_safety_checker( + image, device, prompt_embeds.dtype, enable_safety_guidance + ) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + if flagged_images is not None: + flagged_images = self.numpy_to_pil(flagged_images) + + if not return_dict: + return ( + image, + has_nsfw_concept, + self._safety_text_concept if enable_safety_guidance else None, + flagged_images, + ) + + return StableDiffusionSafePipelineOutput( + images=image, + nsfw_content_detected=has_nsfw_concept, + applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, + unsafe_images=flagged_images, + ) diff --git a/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/diffusers/pipelines/stable_diffusion_safe/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..0b0c547496a0202dbfa1d8525a92565b3df62cbb --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_safe/safety_checker.py @@ -0,0 +1,109 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class SafeStableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + return images, has_nsfw_concepts diff --git a/diffusers/pipelines/stable_diffusion_xl/__init__.py b/diffusers/pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b823cddd3af0915ce6c6afc6717024c34cff886 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL + +from ...utils import BaseOutput, is_invisible_watermark_available, is_torch_available, is_transformers_available + + +@dataclass +class StableDiffusionXLPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +if is_transformers_available() and is_torch_available() and is_invisible_watermark_available(): + from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline + from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline + from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9698b2cfae0d63542a8fef997eb18d3b5249338d Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5042985a37ca3373dc63354236b6da7b89a2ba6a Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..266ba7d3383c96e492cc106e9c94cb16c57941f7 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ff60e5db97d24f42e27bcbdd97a2e4f30d2b182 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cd0434f18b30d15a9e385c392b8fffcfca576a6 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96fde8de78d8ef0add27c1edbc060db10b8e6868 Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b07704f35a52279818d9efcbce001193a7aafaa Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c12b4e004ff5ae13ee4a8aaace69344802173d4d Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaaafb3c3656f15921ce6e777fc38418b4772f4b Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-310.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7019841445a23b7aa7aac1de2094733e68ec77eb Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-38.pyc differ diff --git a/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..9863c663910f99a51d4a635ed50a241bce31f5d0 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -0,0 +1,811 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionXLPipelineOutput +from .watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + + self.watermark = StableDiffusionXLWatermarker() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + model_sequence.extend([self.unet, self.vae]) + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to + 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50) + denoising steps. As a result, the returned sample will still retain a substantial amount of noise. The + denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of + Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None: + num_inference_steps = int(round(denoising_end * num_inference_steps)) + timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..a106ee7ac5f632aac8cddd93b31f8be30a767b29 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -0,0 +1,943 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionXLPipelineOutput +from .watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + _optional_components = ["tokenizer", "text_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.watermark = StableDiffusionXLWatermarker() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + model_sequence.extend([self.unet, self.vae]) + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = int(round(denoising_start * num_inference_steps)) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + strength: float = 0.3, + num_inference_steps: int = 50, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. For example, if `denoising_start` is set to 0.7 and + num_inference_steps is fixed at 50, the process will begin only from the 35th (i.e., 0.7 * 50) + denoising step. Consequently, the initial part of the denoising process is skipped and it is assumed + that the passed `image` is a partly denoised image. The `denoising_start` parameter is particularly + beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as + detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to + 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50) + denoising steps. As a result, the returned sample will still retain a substantial amount of noise (ca. + 30%) and should be denoised by a successor pipeline that has `denoising_start` set to 0.7 so that it + only denoised the final 30%. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + aesthetic_score (`float`, *optional*, defaults to 6.0): + TODO + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + TDOO + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + original_num_steps = num_inference_steps # save for denoising_start/end later + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device, denoising_start=denoising_start + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if denoising_start is None else False + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if denoising_end is not None and denoising_start is not None: + if denoising_start >= denoising_end: + raise ValueError( + f"`denoising_end`: {denoising_end} cannot be larger than `denoising_start`: {denoising_start}." + ) + + skipped_final_steps = int(round((1 - denoising_end) * original_num_steps)) + num_inference_steps = num_inference_steps - skipped_final_steps + timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps] + elif denoising_end is not None: + num_inference_steps = int(round(denoising_end * num_inference_steps)) + timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..55f20660afc7f7556e469c1c193dcee1c763b0a3 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -0,0 +1,1209 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionXLPipelineOutput +from .watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def mask_pil_to_torch(mask, height, width): + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask = torch.from_numpy(mask) + return mask + + +def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + # checkpoint. TOD(Yiyi) - need to clean this up later + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + mask = mask_pil_to_torch(mask, height, width) + + if image.ndim == 3: + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + # if image.min() < -1 or image.max() > 1: + # raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + mask = mask_pil_to_torch(mask, height, width) + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + if image.shape[1] == 4: + # images are in latent space and thus can't + # be masked set masked_image to None + # we assume that the checkpoint is not an inpainting + # checkpoint. TOD(Yiyi) - need to clean this up later + masked_image = None + else: + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class StableDiffusionXLInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + _optional_components = ["tokenizer", "text_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.watermark = StableDiffusionXLWatermarker() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + model_sequence.extend([self.unet, self.vae]) + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + masked_image_latents = None + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = int(round(denoising_start * num_inference_steps)) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. For example, if `denoising_start` is set to 0.7 and + num_inference_steps is fixed at 50, the process will begin only from the 35th (i.e., 0.7 * 50) + denoising step. Consequently, the initial part of the denoising process is skipped and it is assumed + that the passed `image` is a partly denoised image. The `denoising_start` parameter is particularly + beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as + detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to + 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50) + denoising steps. As a result, the returned sample will still retain a substantial amount of noise (ca. + 30%) and should be denoised by a successor pipeline that has `denoising_start` set to 0.7 so that it + only denoised the final 30%. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. set timesteps + original_num_steps = num_inference_steps # save for denoising_start/end later + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device, denoising_start=denoising_start + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if denoising_end is not None and denoising_start is not None: + if denoising_start >= denoising_end: + raise ValueError( + f"`denoising_end`: {denoising_end} cannot be larger than `denoising_start`: {denoising_start}." + ) + + skipped_final_steps = int(round((1 - denoising_end) * original_num_steps)) + num_inference_steps = num_inference_steps - skipped_final_steps + timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps] + elif denoising_end is not None: + num_inference_steps = int(round(denoising_end * num_inference_steps)) + timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + return StableDiffusionXLPipelineOutput(images=latents) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/stable_diffusion_xl/watermark.py b/diffusers/pipelines/stable_diffusion_xl/watermark.py new file mode 100644 index 0000000000000000000000000000000000000000..bc6c9bf649b161fbc1ae7e59b3de6ba5c22884fa --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -0,0 +1,31 @@ +import numpy as np +import torch +from imwatermark import WatermarkEncoder + + +# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66 +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] + + +class StableDiffusionXLWatermarker: + def __init__(self): + self.watermark = WATERMARK_BITS + self.encoder = WatermarkEncoder() + + self.encoder.set_watermark("bits", self.watermark) + + def apply_watermark(self, images: torch.FloatTensor): + # can't encode images that are smaller than 256 + if images.shape[-1] < 256: + return images + + images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() + + images = [self.encoder.encode(image, "dwtDct") for image in images] + + images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2) + + images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) + return images diff --git a/diffusers/pipelines/stochastic_karras_ve/__init__.py b/diffusers/pipelines/stochastic_karras_ve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a63c1d24afb2c4f36b0e284f0985a3ff508f4c7 --- /dev/null +++ b/diffusers/pipelines/stochastic_karras_ve/__init__.py @@ -0,0 +1 @@ +from .pipeline_stochastic_karras_ve import KarrasVePipeline diff --git a/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b45bf0b5a92059d18ea3e387ad086adcb11530a1 Binary files /dev/null and b/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41d7a3fa3ac852e688e83f93b802fd473ce8c5dd Binary files /dev/null and b/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc b/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c58ff51d48526a06030863fa616796649355c1f0 Binary files /dev/null and b/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc differ diff --git a/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc b/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a757369efbbcf9927b89641a0a920c6ddd7f506 Binary files /dev/null and b/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc differ diff --git a/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0ab15eb9758c42116cf67aab6d9d8a5a6dad7d --- /dev/null +++ b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -0,0 +1,128 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch + +from ...models import UNet2DModel +from ...schedulers import KarrasVeScheduler +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class KarrasVePipeline(DiffusionPipeline): + r""" + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`KarrasVeScheduler`]): + Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image. + """ + + # add type hints for linting + unet: UNet2DModel + scheduler: KarrasVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + # sample x_0 ~ N(0, sigma_0^2 * I) + sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # here sigma_t == t_i from the paper + sigma = self.scheduler.schedule[t] + sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 + + # 1. Select temporarily increased noise level sigma_hat + # 2. Add new noise to move from sample_i to sample_hat + sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) + + # 3. Predict the noise residual given the noise magnitude `sigma_hat` + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample + + # 4. Evaluate dx/dt at sigma_hat + # 5. Take Euler step from sigma to sigma_prev + step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) + + if sigma_prev != 0: + # 6. Apply 2nd order correction + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample + step_output = self.scheduler.step_correct( + model_output, + sigma_hat, + sigma_prev, + sample_hat, + step_output.prev_sample, + step_output["derivative"], + ) + sample = step_output.prev_sample + + sample = (sample / 2 + 0.5).clamp(0, 1) + image = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/t2i_adapter/__init__.py b/diffusers/pipelines/t2i_adapter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4de661dbefab846cc72fa576675aad6cab1d134 --- /dev/null +++ b/diffusers/pipelines/t2i_adapter/__init__.py @@ -0,0 +1,14 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline diff --git a/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50233c532f5cd2f75ca4613ac7105f1cc8e95bb4 Binary files /dev/null and b/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff267f1d748af7fc5eb7fa7c83b89a53cf55e497 Binary files /dev/null and b/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-310.pyc b/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b9047f536177ed3d9489970c134433453fd7c0b Binary files /dev/null and b/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-310.pyc differ diff --git a/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-38.pyc b/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..614a87a38195c4b67b9847885e4012fc4d4d0049 Binary files /dev/null and b/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-38.pyc differ diff --git a/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..97c6fb99157f78ad83b426e5cba98855916859f4 --- /dev/null +++ b/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -0,0 +1,774 @@ +# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + BaseOutput, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +@dataclass +class StableDiffusionAdapterPipelineOutput(BaseOutput): + """ + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> from diffusers.utils import load_image + >>> import torch + >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter + + >>> image = load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png" + ... ) + + >>> color_palette = image.resize((8, 8)) + >>> color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST) + + >>> adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16) + >>> pipe = StableDiffusionAdapterPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... adapter=adapter, + ... torch_dtype=torch.float16, + ... ) + + >>> pipe.to("cuda") + + >>> out_image = pipe( + ... "At night, glowing cubes in front of the beach", + ... image=color_palette, + ... ).images[0] + ``` +""" + + +def _preprocess_adapter_image(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: + image = torch.stack(image, dim=0) + elif image[0].ndim == 4: + image = torch.cat(image, dim=0) + else: + raise ValueError( + f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" + ) + return image + + +class StableDiffusionAdapterPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter + https://arxiv.org/abs/2302.08453 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a + list, the outputs from each Adapter are added together to create one combined additional conditioning. + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding them + together. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + adapter_weights: Optional[List[float]] = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(adapter, (list, tuple)): + adapter = MultiAdapter(adapter, adapter_weights=adapter_weights) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + adapter=adapter, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.adapter, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.total_downscale_factor` + height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.total_downscale_factor` + width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor + + return height, width + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + adapter_conditioning_scale: Union[float, List[float]] = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): + The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the + type is specified as `Torch.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + accepted as an image. The control image is automatically resized to fit the output image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + device = self._execution_device + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + is_multi_adapter = isinstance(self.adapter, MultiAdapter) + if is_multi_adapter: + adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image] + n, c, h, w = adapter_input[0].shape + adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input]) + else: + adapter_input = _preprocess_adapter_image(image, height, width).to(device) + adapter_input = adapter_input.to(self.adapter.dtype) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=[state.clone() for state in adapter_state], + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/text_to_video_synthesis/__init__.py b/diffusers/pipelines/text_to_video_synthesis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d70c1c2ea2a8af8d69aebb915c9d6eacc52c14f8 --- /dev/null +++ b/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch + +from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +@dataclass +class TextToVideoSDPipelineOutput(BaseOutput): + """ + Output class for text to video pipelines. + + Args: + frames (`List[np.ndarray]` or `torch.FloatTensor`) + List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as + a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list + denotes the video length i.e., the number of frames. + """ + + frames: Union[List[np.ndarray], torch.FloatTensor] + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_text_to_video_synth import TextToVideoSDPipeline + from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline # noqa: F401 + from .pipeline_text_to_video_zero import TextToVideoZeroPipeline diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebd8a72dcf14b5cacfbaea569b5b26f719649559 Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b12a4dda8aaa4e029d9702dfe03b073626cfa118 Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ba4a8d74f18ecbb3c873a45ce0a26eec2374f91 Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc differ diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-38.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9aa7cc5ce96ad129c3f96f363f5dcbfb597b8660 Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-38.pyc differ diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-310.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4020b8e75dcc1cb257b638d91af03639a492af03 Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-310.pyc differ diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-38.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef7227c624de539ebc347ff50efaa76660049946 Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-38.pyc differ diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-310.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd87e8e7a841f0aef79dc2e24e5fa27ede71d228 Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-310.pyc differ diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-38.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cd0482c1785c7aca6631edbbfbd9eea62d3021e Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-38.pyc differ diff --git a/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py new file mode 100644 index 0000000000000000000000000000000000000000..dad7d563989255efd589c65034397d5f1a33b5e4 --- /dev/null +++ b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -0,0 +1,653 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet3DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import TextToVideoSDPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import TextToVideoSDPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = TextToVideoSDPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Same as Stable Diffusion 2. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + return TextToVideoSDPipelineOutput(frames=latents) + + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..72a5b762d5048928a5bd5295acc303a60e5edf18 --- /dev/null +++ b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -0,0 +1,731 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet3DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import TextToVideoSDPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler + >>> from diffusers.utils import export_to_video + + >>> pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.to("cuda") + + >>> prompt = "spiderman running in the desert" + >>> video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames + >>> # safe low-res video + >>> video_path = export_to_video(video_frames, output_video_path="./video_576_spiderman.mp4") + + >>> # let's offload the text-to-image model + >>> pipe.to("cpu") + + >>> # and load the image-to-image model + >>> pipe = DiffusionPipeline.from_pretrained( + ... "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/15" + ... ) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # The VAE consumes A LOT of memory, let's make sure we run it in sliced mode + >>> pipe.vae.enable_slicing() + + >>> # now let's upscale it + >>> video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] + + >>> # and denoise it + >>> video_frames = pipe(prompt, video=video, strength=0.6).frames + >>> video_path = export_to_video(video_frames, output_video_path="./video_1024_spiderman.mp4") + >>> video_path + ``` +""" + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +def preprocess_video(video): + supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) + + if isinstance(video, supported_formats): + video = [video] + elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" + ) + + if isinstance(video[0], PIL.Image.Image): + video = [np.array(frame) for frame in video] + + if isinstance(video[0], np.ndarray): + video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) + + if video.dtype == np.uint8: + video = np.array(video).astype(np.float32) / 255.0 + + if video.ndim == 4: + video = video[None, ...] + + video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3)) + + elif isinstance(video[0], torch.Tensor): + video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0) + + # don't need any preprocess if the video is latents + channel = video.shape[1] + if channel == 4: + return video + + # move channels before num_frames + video = video.permute(0, 2, 1, 3, 4) + + # normalize video + video = 2.0 * video - 1.0 + + return video + + +class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Same as Stable Diffusion 2. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.vae, self.unet]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, video, timestep, batch_size, dtype, device, generator=None): + video = video.to(device=device, dtype=dtype) + + # change from (b, c, f, h, w) -> (b * f, c, w, h) + bsz, channel, frames, width, height = video.shape + video = video.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + if video.shape[1] == 4: + init_latents = video + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(video).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + latents = latents[None, :].reshape((bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4) + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + video: Union[List[np.ndarray], torch.FloatTensor] = None, + strength: float = 0.6, + num_inference_steps: int = 50, + guidance_scale: float = 15.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + video: (`List[np.ndarray]` or `torch.FloatTensor`): + `video` frames or tensor representing a video batch, that will be used as the starting point for the + process. Can also accpet video latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess video + video = preprocess_video(video) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + return TextToVideoSDPipelineOutput(frames=latents) + + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7207f904f08032c3f125d64bf5f024a6b89b60 --- /dev/null +++ b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -0,0 +1,627 @@ +import copy +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from torch.nn.functional import grid_sample +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import BaseOutput + + +def rearrange_0(tensor, f): + F, C, H, W = tensor.size() + tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4)) + return tensor + + +def rearrange_1(tensor): + B, C, F, H, W = tensor.size() + return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W)) + + +def rearrange_3(tensor, f): + F, D, C = tensor.size() + return torch.reshape(tensor, (F // f, f, D, C)) + + +def rearrange_4(tensor): + B, F, D, C = tensor.size() + return torch.reshape(tensor, (B * F, D, C)) + + +class CrossFrameAttnProcessor: + """ + Cross frame attention processor. Each frame attends the first frame. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = key.size()[0] // self.batch_size + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class CrossFrameAttnProcessor2_0: + """ + Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = key.size()[0] // self.batch_size + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +@dataclass +class TextToVideoPipelineOutput(BaseOutput): + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +def coords_grid(batch, ht, wd, device): + # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def warp_single_latent(latent, reference_flow): + """ + Warp latent of a single frame with given flow + + Args: + latent: latent code of a single frame + reference_flow: flow which to warp the latent with + + Returns: + warped: warped latent + """ + _, _, H, W = reference_flow.size() + _, _, h, w = latent.size() + coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype) + + coords_t0 = coords0 + reference_flow + coords_t0[:, 0] /= W + coords_t0[:, 1] /= H + + coords_t0 = coords_t0 * 2.0 - 1.0 + coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear") + coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1)) + + warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection") + return warped + + +def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype): + """ + Create translation motion field + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + device: device + dtype: dtype + + Returns: + + """ + seq_length = len(frame_ids) + reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype) + for fr_idx in range(seq_length): + reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx]) + reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx]) + return reference_flow + + +def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents): + """ + Creates translation motion and warps the latents accordingly + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + latents: latent codes of frames + + Returns: + warped_latents: warped latents + """ + motion_field = create_motion_field( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + frame_ids=frame_ids, + device=latents.device, + dtype=latents.dtype, + ) + warped_latents = latents.clone().detach() + for i in range(len(warped_latents)): + warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None]) + return warped_latents + + +class TextToVideoZeroPipeline(StableDiffusionPipeline): + r""" + Pipeline for zero-shot text-to-video generation using Stable Diffusion. + + This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods + the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) + processor = ( + CrossFrameAttnProcessor2_0(batch_size=2) + if hasattr(F, "scaled_dot_product_attention") + else CrossFrameAttnProcessor(batch_size=2) + ) + self.unet.set_attn_processor(processor) + + def forward_loop(self, x_t0, t0, t1, generator): + """ + Perform ddpm forward process from time t0 to t1. This is the same as adding noise with corresponding variance. + + Args: + x_t0: latent code at time t0 + t0: t0 + t1: t1 + generator: torch.Generator object + + Returns: + x_t1: forward process applied to x_t0 from time t0 to t1. + """ + eps = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device) + alpha_vec = torch.prod(self.scheduler.alphas[t0:t1]) + x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps + return x_t1 + + def backward_loop( + self, + latents, + timesteps, + prompt_embeds, + guidance_scale, + callback, + callback_steps, + num_warmup_steps, + extra_step_kwargs, + cross_attention_kwargs=None, + ): + """ + Perform backward process given list of time steps + + Args: + latents: Latents at time timesteps[0]. + timesteps: time steps, along which to perform backward process. + prompt_embeds: Pre-generated text embeddings + guidance_scale: + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + extra_step_kwargs: extra_step_kwargs. + cross_attention_kwargs: cross_attention_kwargs. + num_warmup_steps: number of warmup steps. + + Returns: + latents: latents of backward process output at time timesteps[-1] + """ + do_classifier_free_guidance = guidance_scale > 1.0 + num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order + with self.progress_bar(total=num_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + return latents.clone().detach() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int] = 8, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + motion_field_strength_x: float = 12, + motion_field_strength_y: float = 12, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + t0: int = 44, + t1: int = 47, + frame_ids: Optional[List[int]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + video_length (`int`, *optional*, defaults to 8): The number of generated video frames + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"numpy"`): + The output format of the generated image. Choose between `"latent"` and `"numpy"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + motion_field_strength_x (`float`, *optional*, defaults to 12): + Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439), + Sect. 3.3.1. + motion_field_strength_y (`float`, *optional*, defaults to 12): + Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439), + Sect. 3.3.1. + t0 (`int`, *optional*, defaults to 44): + Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the + [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. + t1 (`int`, *optional*, defaults to 47): + Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the + [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. + frame_ids (`List[int]`, *optional*): + Indexes of the frames that are being generated. This is used when generating longer videos + chunk-by-chunk. + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]: + The output contains a ndarray of the generated images, when output_type != 'latent', otherwise a latent + codes of generated image, and a list of `bool`s denoting whether the corresponding generated image + likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + assert video_length > 0 + if frame_ids is None: + frame_ids = list(range(video_length)) + assert len(frame_ids) == video_length + + assert num_videos_per_prompt == 1 + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Perform the first backward process up to time T_1 + x_1_t1 = self.backward_loop( + timesteps=timesteps[: -t1 - 1], + prompt_embeds=prompt_embeds, + latents=latents, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=num_warmup_steps, + ) + scheduler_copy = copy.deepcopy(self.scheduler) + + # Perform the second backward process up to time T_0 + x_1_t0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 : -t0 - 1], + prompt_embeds=prompt_embeds, + latents=x_1_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + ) + + # Propagate first frame latents at time T_0 to remaining frames + x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1) + + # Add motion in latents at time T_0 + x_2k_t0 = create_motion_field_and_warp_latents( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + latents=x_2k_t0, + frame_ids=frame_ids[1:], + ) + + # Perform forward process up to time T_1 + x_2k_t1 = self.forward_loop( + x_t0=x_2k_t0, + t0=timesteps[-t0 - 1].item(), + t1=timesteps[-t1 - 1].item(), + generator=generator, + ) + + # Perform backward process from time T_1 to 0 + x_1k_t1 = torch.cat([x_1_t1, x_2k_t1]) + b, l, d = prompt_embeds.size() + prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d) + + self.scheduler = scheduler_copy + x_1k_0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 :], + prompt_embeds=prompt_embeds, + latents=x_1k_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + ) + latents = x_1k_0 + + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + torch.cuda.empty_cache() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + else: + image = self.decode_latents(latents) + # Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/unclip/__init__.py b/diffusers/pipelines/unclip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..075e66bb680aca294b36aa7ad0abb8d0f651cd92 --- /dev/null +++ b/diffusers/pipelines/unclip/__init__.py @@ -0,0 +1,17 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline +else: + from .pipeline_unclip import UnCLIPPipeline + from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline + from .text_proj import UnCLIPTextProjModel diff --git a/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2085b76dfabfae74a7c9cb418296e470b1af24a Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/unclip/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/unclip/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f2ede158e4989d91231581c4f453c6e6f4f42d3 Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..507fdd91c7f7a619e99a6c49f5642f56af39255e Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc differ diff --git a/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-38.pyc b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d93a6a4beda48656b367e5182f6792c218ec9ae9 Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-38.pyc differ diff --git a/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be06f106c2689b68daa661ebcd28dbb563fac5c6 Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc differ diff --git a/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-38.pyc b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8300f16906ccd5eae78cb8d2e2a4a7732ac69d4e Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-38.pyc differ diff --git a/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc b/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbd8e45c4a4d3661fadfe6e30f6da8de61380dfa Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc differ diff --git a/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-38.pyc b/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17d05790af5ac31577478551db9394f1203ebd8f Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-38.pyc differ diff --git a/diffusers/pipelines/unclip/pipeline_unclip.py b/diffusers/pipelines/unclip/pipeline_unclip.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b67bebfc8dc9bbc530fdc1f283bd356662cd2a --- /dev/null +++ b/diffusers/pipelines/unclip/pipeline_unclip.py @@ -0,0 +1,493 @@ +# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import functional as F +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import UnCLIPScheduler +from ...utils import logging, randn_tensor +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using unCLIP + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution unet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution unet. Used in the last step of the super resolution diffusion process. + prior_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the prior denoising process. Just a modified DDPMScheduler. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. + + """ + + _exclude_from_cpu_offload = ["prior"] + + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prior_latents: Optional[torch.FloatTensor] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. This can only be left undefined if + `text_model_output` and `text_attention_mask` is passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs + can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left to `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + """ + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + else: + batch_size = text_model_output[0].shape[0] + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size + + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py new file mode 100644 index 0000000000000000000000000000000000000000..580417f517becdc159b727ea6e59d3430c6b1076 --- /dev/null +++ b/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -0,0 +1,420 @@ +# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import PIL +import torch +from torch.nn import functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...models import UNet2DConditionModel, UNet2DModel +from ...pipelines import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import UnCLIPScheduler +from ...utils import logging, randn_tensor +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPImageVariationPipeline(DiffusionPipeline): + """ + Pipeline to generate variations from an input image using unCLIP + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution unet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution unet. Used in the last step of the super resolution diffusion process. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. + + """ + + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + def __init__( + self, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=super_res_first, + super_res_last=super_res_last, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None): + dtype = next(self.image_encoder.parameters()).dtype + + if image_embeddings is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + + image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None, + num_images_per_prompt: int = 1, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[torch.Generator] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + image_embeddings: Optional[torch.Tensor] = None, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image or images to guide the image generation. If you provide a tensor, it needs to comply with the + configuration of + [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) + `CLIPImageProcessor`. Can be left to `None` only when `image_embeddings` are passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + image_embeddings (`torch.Tensor`, *optional*): + Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings + can be passed for tasks like image interpolations. `image` can the be left to `None`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + """ + if image is not None: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + else: + batch_size = image_embeddings.shape[0] + + prompt = [""] * batch_size + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = decoder_guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance + ) + + image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings) + + # decoder + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size + + if decoder_latents is None: + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size + + if super_res_latents is None: + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/unclip/text_proj.py b/diffusers/pipelines/unclip/text_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..0414559500c16484dd326f72d04a5306dc14682e --- /dev/null +++ b/diffusers/pipelines/unclip/text_proj.py @@ -0,0 +1,86 @@ +# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +class UnCLIPTextProjModel(ModelMixin, ConfigMixin): + """ + Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the + decoder. + + For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1 + """ + + @register_to_config + def __init__( + self, + *, + clip_extra_context_tokens: int = 4, + clip_embeddings_dim: int = 768, + time_embed_dim: int, + cross_attention_dim, + ): + super().__init__() + + self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim)) + + # parameters for additional clip time embeddings + self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) + self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim) + + # parameters for encoder hidden states + self.clip_extra_context_tokens = clip_extra_context_tokens + self.clip_extra_context_tokens_proj = nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim) + self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance): + if do_classifier_free_guidance: + # Add the classifier free guidance embeddings to the image embeddings + image_embeddings_batch_size = image_embeddings.shape[0] + classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0) + classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand( + image_embeddings_batch_size, -1 + ) + image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0) + + # The image embeddings batch size and the text embeddings batch size are equal + assert image_embeddings.shape[0] == prompt_embeds.shape[0] + + batch_size = prompt_embeds.shape[0] + + # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and + # adding CLIP embeddings to the existing timestep embedding, ... + time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) + time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) + additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds + + # ... and by projecting CLIP embeddings into four + # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" + clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings) + clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens) + clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1) + + text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states) + text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states) + text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1) + + return text_encoder_hidden_states, additive_clip_time_embeddings diff --git a/diffusers/pipelines/unidiffuser/__init__.py b/diffusers/pipelines/unidiffuser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a774e3274030153d20618024b8c2bc6385ef367a --- /dev/null +++ b/diffusers/pipelines/unidiffuser/__init__.py @@ -0,0 +1,20 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + ImageTextPipelineOutput, + UniDiffuserPipeline, + ) +else: + from .modeling_text_decoder import UniDiffuserTextDecoder + from .modeling_uvit import UniDiffuserModel, UTransformer2DModel + from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline diff --git a/diffusers/pipelines/unidiffuser/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/unidiffuser/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44c45f26d08a473339348ec6f11f058e7fa3c677 Binary files /dev/null and b/diffusers/pipelines/unidiffuser/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/unidiffuser/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/unidiffuser/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0061d9c3db9693f42a3f807619530a4015de8e Binary files /dev/null and b/diffusers/pipelines/unidiffuser/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/unidiffuser/__pycache__/modeling_text_decoder.cpython-310.pyc b/diffusers/pipelines/unidiffuser/__pycache__/modeling_text_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1111fbb6be89b16a0161a8818577e16accc10b1f Binary files /dev/null and b/diffusers/pipelines/unidiffuser/__pycache__/modeling_text_decoder.cpython-310.pyc differ diff --git a/diffusers/pipelines/unidiffuser/__pycache__/modeling_text_decoder.cpython-38.pyc b/diffusers/pipelines/unidiffuser/__pycache__/modeling_text_decoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4677e78df80be1364c77474856d2078b80767a3 Binary files /dev/null and b/diffusers/pipelines/unidiffuser/__pycache__/modeling_text_decoder.cpython-38.pyc differ diff --git a/diffusers/pipelines/unidiffuser/__pycache__/modeling_uvit.cpython-310.pyc b/diffusers/pipelines/unidiffuser/__pycache__/modeling_uvit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bb588b0d2d9894ea55ca1e65bda09510f140be9 Binary files /dev/null and b/diffusers/pipelines/unidiffuser/__pycache__/modeling_uvit.cpython-310.pyc differ diff --git a/diffusers/pipelines/unidiffuser/__pycache__/modeling_uvit.cpython-38.pyc b/diffusers/pipelines/unidiffuser/__pycache__/modeling_uvit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1323a127f051c964cb645b1f00499f5125f9914 Binary files /dev/null and b/diffusers/pipelines/unidiffuser/__pycache__/modeling_uvit.cpython-38.pyc differ diff --git a/diffusers/pipelines/unidiffuser/__pycache__/pipeline_unidiffuser.cpython-310.pyc b/diffusers/pipelines/unidiffuser/__pycache__/pipeline_unidiffuser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d445d9019596de4921230134021c146b55e2f882 Binary files /dev/null and b/diffusers/pipelines/unidiffuser/__pycache__/pipeline_unidiffuser.cpython-310.pyc differ diff --git a/diffusers/pipelines/unidiffuser/__pycache__/pipeline_unidiffuser.cpython-38.pyc b/diffusers/pipelines/unidiffuser/__pycache__/pipeline_unidiffuser.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0af0456a58a8db7011979345d2ef9a452fccd87 Binary files /dev/null and b/diffusers/pipelines/unidiffuser/__pycache__/pipeline_unidiffuser.cpython-38.pyc differ diff --git a/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/diffusers/pipelines/unidiffuser/modeling_text_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9b962f6e065621c8fc83775f555bbd732ccc8a26 --- /dev/null +++ b/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -0,0 +1,296 @@ +from typing import Optional + +import numpy as np +import torch +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel +from transformers.modeling_utils import ModuleUtilsMixin + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py +class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + """ + Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to + generate text from the UniDiffuser image-text embedding. + + Parameters: + prefix_length (`int`): + Max number of prefix tokens that will be supplied to the model. + prefix_inner_dim (`int`): + The hidden size of the the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the + CLIP text encoder. + prefix_hidden_dim (`int`, *optional*): + Hidden dim of the MLP if we encode the prefix. + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + """ + + _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"] + + @register_to_config + def __init__( + self, + prefix_length: int, + prefix_inner_dim: int, + prefix_hidden_dim: Optional[int] = None, + vocab_size: int = 50257, # Start of GPT2 config args + n_positions: int = 1024, + n_embd: int = 768, + n_layer: int = 12, + n_head: int = 12, + n_inner: Optional[int] = None, + activation_function: str = "gelu_new", + resid_pdrop: float = 0.1, + embd_pdrop: float = 0.1, + attn_pdrop: float = 0.1, + layer_norm_epsilon: float = 1e-5, + initializer_range: float = 0.02, + scale_attn_weights: bool = True, + use_cache: bool = True, + scale_attn_by_inverse_layer_idx: bool = False, + reorder_and_upcast_attn: bool = False, + ): + super().__init__() + + self.prefix_length = prefix_length + + if prefix_inner_dim != n_embd and prefix_hidden_dim is None: + raise ValueError( + f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_hidden_dim} and" + f" `n_embd`: {n_embd} are not equal." + ) + + self.prefix_inner_dim = prefix_inner_dim + self.prefix_hidden_dim = prefix_hidden_dim + + self.encode_prefix = ( + nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim) + if self.prefix_hidden_dim is not None + else nn.Identity() + ) + self.decode_prefix = ( + nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity() + ) + + gpt_config = GPT2Config( + vocab_size=vocab_size, + n_positions=n_positions, + n_embd=n_embd, + n_layer=n_layer, + n_head=n_head, + n_inner=n_inner, + activation_function=activation_function, + resid_pdrop=resid_pdrop, + embd_pdrop=embd_pdrop, + attn_pdrop=attn_pdrop, + layer_norm_epsilon=layer_norm_epsilon, + initializer_range=initializer_range, + scale_attn_weights=scale_attn_weights, + use_cache=use_cache, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + self.transformer = GPT2LMHeadModel(gpt_config) + + def forward( + self, + input_ids: torch.Tensor, + prefix_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + """ + Args: + input_ids (`torch.Tensor` of shape `(N, max_seq_len)`): + Text tokens to use for inference. + prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`): + Prefix embedding to preprend to the embedded tokens. + attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*): + Attention mask for the prefix embedding. + labels (`torch.Tensor`, *optional*): + Labels to use for language modeling. + """ + embedding_text = self.transformer.transformer.wte(input_ids) + hidden = self.encode_prefix(prefix_embeds) + prefix_embeds = self.decode_prefix(hidden) + embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1) + + if labels is not None: + dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device) + labels = torch.cat((dummy_token, input_ids), dim=1) + out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask) + if self.prefix_hidden_dim is not None: + return out, hidden + else: + return out + + def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: + return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) + + def encode(self, prefix): + return self.encode_prefix(prefix) + + @torch.no_grad() + def generate_captions(self, features, eos_token_id, device): + """ + Generate captions given text embedding features. Returns list[L]. + + Args: + features (`torch.Tensor` of shape `(B, L, D)`): + Text embedding features to generate captions from. + eos_token_id (`int`): + The token ID of the EOS token for the text decoder model. + device: + Device to perform text generation on. + + Returns: + `List[str]`: A list of strings generated from the decoder model. + """ + + features = torch.split(features, 1, dim=0) + generated_tokens = [] + generated_seq_lengths = [] + for feature in features: + feature = self.decode_prefix(feature.to(device)) # back to the clip feature + # Only support beam search for now + output_tokens, seq_lengths = self.generate_beam( + input_embeds=feature, device=device, eos_token_id=eos_token_id + ) + generated_tokens.append(output_tokens[0]) + generated_seq_lengths.append(seq_lengths[0]) + generated_tokens = torch.stack(generated_tokens) + generated_seq_lengths = torch.stack(generated_seq_lengths) + return generated_tokens, generated_seq_lengths + + @torch.no_grad() + def generate_beam( + self, + input_ids=None, + input_embeds=None, + device=None, + beam_size: int = 5, + entry_length: int = 67, + temperature: float = 1.0, + eos_token_id: Optional[int] = None, + ): + """ + Generates text using the given tokenizer and text prompt or token embedding via beam search. This + implementation is based on the beam search implementation from the [original UniDiffuser + code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89). + + Args: + eos_token_id (`int`, *optional*): + The token ID of the EOS token for the text decoder model. + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds` + must be supplied. + input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + An embedded representation to directly pass to the transformer as a prefix for beam search. One of + `input_ids` and `input_embeds` must be supplied. + device: + The device to perform beam search on. + beam_size (`int`, *optional*, defaults to `5`): + The number of best states to store during beam search. + entry_length (`int`, *optional*, defaults to `67`): + The number of iterations to run beam search. + temperature (`float`, *optional*, defaults to 1.0): + The temperature to use when performing the softmax over logits from the decoding model. + + Returns: + `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated + token sequences sorted by score in descending order, and the second element is the sequence lengths + corresponding to those sequences. + """ + # Generates text until stop_token is reached using beam search with the desired beam size. + stop_token_index = eos_token_id + tokens = None + scores = None + seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int) + is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool) + + if input_embeds is not None: + generated = input_embeds + else: + generated = self.transformer.transformer.wte(input_ids) + + for i in range(entry_length): + outputs = self.transformer(inputs_embeds=generated) + logits = outputs.logits + logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + logits = logits.softmax(-1).log() + + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + generated = generated.expand(beam_size, *generated.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + if tokens is None: + tokens = next_tokens + else: + tokens = tokens.expand(beam_size, *tokens.shape[1:]) + tokens = torch.cat((tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + tokens = tokens[next_tokens_source] + tokens = torch.cat((tokens, next_tokens), dim=1) + generated = generated[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + + next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1) + generated = torch.cat((generated, next_token_embed), dim=1) + is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() + if is_stopped.all(): + break + + scores = scores / seq_lengths + order = scores.argsort(descending=True) + # tokens tensors are already padded to max_seq_length + output_texts = [tokens[i] for i in order] + output_texts = torch.stack(output_texts, dim=0) + seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype) + return output_texts, seq_lengths diff --git a/diffusers/pipelines/unidiffuser/modeling_uvit.py b/diffusers/pipelines/unidiffuser/modeling_uvit.py new file mode 100644 index 0000000000000000000000000000000000000000..b7829f76ec12f946490618e0d03857777efdf219 --- /dev/null +++ b/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -0,0 +1,1196 @@ +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...models.attention import AdaLayerNorm, FeedForward +from ...models.attention_processor import Attention +from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed +from ...models.transformer_2d import Transformer2DModelOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warning( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect." + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, + \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for + generating the random values works best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + use_pos_embed=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.use_pos_embed = use_pos_embed + if self.use_pos_embed: + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.use_pos_embed: + return latent + self.pos_embed + else: + return latent + + +class SkipBlock(nn.Module): + def __init__(self, dim: int): + super().__init__() + + self.skip_linear = nn.Linear(2 * dim, dim) + + # Use torch.nn.LayerNorm for now, following the original code + self.norm = nn.LayerNorm(dim) + + def forward(self, x, skip): + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = self.norm(x) + + return x + + +# Modified to support both pre-LayerNorm and post-LayerNorm configurations +# Don't support AdaLayerNormZero for now +# Modified from diffusers.models.attention.BasicTransformerBlock +class UTransformerBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g. + `pre_layer_norm = True`. + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = True, + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + else: + norm_hidden_states = self.norm1(hidden_states) + else: + norm_hidden_states = hidden_states + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + attn_output = self.norm1(attn_output, timestep) + else: + attn_output = self.norm1(attn_output) + + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + else: + norm_hidden_states = hidden_states + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output) + + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = self.norm3(hidden_states) + else: + norm_hidden_states = hidden_states + + ff_output = self.ff(norm_hidden_states) + + # Post-LayerNorm + if not self.pre_layer_norm: + ff_output = self.norm3(ff_output) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +# Like UTransformerBlock except with LayerNorms on the residual backbone of the block +# Modified from diffusers.models.attention.BasicTransformerBlock +class UniDiffuserBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the + LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser + implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104). + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = False, + final_dropout: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + ff_output = self.ff(hidden_states) + + hidden_states = ff_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + return hidden_states + + +# Modified from diffusers.models.transformer_2d.Transformer2DModel +# Modify the transformer block structure to be U-Net like following U-ViT +# Only supports patch-style input and torch.nn.LayerNorm currently +# https://github.com/baofff/U-ViT +class UTransformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared + to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion, + similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`] + layer and then reshaped to (b, t, d). + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = 2, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Input + # Only support patch input of shape (batch_size, num_channels, height, width) for now + assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size." + + assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size" + + # 2. Define input layers + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + + # 3. Define transformers blocks + # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block, + # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in + # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). + # Quick hack to make the transformer block type configurable + if block_type == "unidiffuser": + block_cls = UniDiffuserBlock + else: + block_cls = UTransformerBlock + self.transformer_in_blocks = nn.ModuleList( + [ + block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + for d in range(num_layers // 2) + ] + ) + + self.transformer_mid_block = block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + + # For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs + # before each transformer out_block. + self.transformer_out_blocks = nn.ModuleList( + [ + nn.ModuleDict( + { + "skip": SkipBlock( + inner_dim, + ), + "block": block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ), + } + ) + for d in range(num_layers // 2) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + + # Following the UniDiffuser U-ViT implementation, we process the transformer output with + # a LayerNorm layer with per-element affine params + self.norm_out = nn.LayerNorm(inner_dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + hidden_states_is_embedding: bool = False, + unpatchify: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): + Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will + ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the + transformer blocks. + unpatchify (`bool`, *optional*, defaults to `True`): + Whether to unpatchify the transformer output. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 0. Check inputs + + if not unpatchify and return_dict: + raise ValueError( + f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when" + f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)" + " rather than (batch_size, num_channels, height, width)." + ) + + # 1. Input + if not hidden_states_is_embedding: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + + # In ("downsample") blocks + skips = [] + for in_block in self.transformer_in_blocks: + hidden_states = in_block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + skips.append(hidden_states) + + # Mid block + hidden_states = self.transformer_mid_block(hidden_states) + + # Out ("upsample") blocks + for out_block in self.transformer_out_blocks: + hidden_states = out_block["skip"](hidden_states, skips.pop()) + hidden_states = out_block["block"]( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + # Don't support AdaLayerNorm for now, so no conditioning/scale/shift logic + hidden_states = self.norm_out(hidden_states) + # hidden_states = self.proj_out(hidden_states) + + if unpatchify: + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + else: + output = hidden_states + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class UniDiffuserModel(ModelMixin, ConfigMixin): + """ + Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a + modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the + CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details). + + Parameters: + text_dim (`int`): The hidden dimension of the CLIP text model used to embed images. + clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts. + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + ff_final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + use_data_type_embedding (`bool`, *optional*): + Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1 + is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type` + argument, which can either be `1` to use the weights trained on non-publically-available data or `0` + otherwise. This argument is subsequently embedded by the data type embedding, if used. + """ + + @register_to_config + def __init__( + self, + text_dim: int = 768, + clip_img_dim: int = 512, + num_text_tokens: int = 77, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + use_timestep_embedding=False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = True, + use_data_type_embedding: bool = False, + ): + super().__init__() + + # 0. Handle dimensions + self.inner_dim = num_attention_heads * attention_head_dim + + assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size" + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.patch_size = patch_size + # Assume image is square... + self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size) + + # 1. Define input layers + # 1.1 Input layers for text and image input + # For now, only support patch input for VAE latent image input + self.vae_img_in = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim) + self.text_in = nn.Linear(text_dim, self.inner_dim) + + # 1.2. Timestep embeddings for t_img, t_text + self.timestep_img_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_img_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + self.timestep_text_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_text_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + # 1.3. Positional embedding + self.num_text_tokens = num_text_tokens + self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim)) + self.pos_embed_drop = nn.Dropout(p=dropout) + trunc_normal_(self.pos_embed, std=0.02) + + # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary + self.use_data_type_embedding = use_data_type_embedding + if self.use_data_type_embedding: + self.data_type_token_embedding = nn.Embedding(2, self.inner_dim) + self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim)) + + # 2. Define transformer blocks + self.transformer = UTransformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=out_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + patch_size=patch_size, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + block_type=block_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + use_patch_pos_embed=use_patch_pos_embed, + ff_final_dropout=ff_final_dropout, + ) + + # 3. Define output layers + patch_dim = (patch_size**2) * out_channels + self.vae_img_out = nn.Linear(self.inner_dim, patch_dim) + self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim) + self.text_out = nn.Linear(self.inner_dim, text_dim) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + def forward( + self, + latent_image_embeds: torch.FloatTensor, + image_embeds: torch.FloatTensor, + prompt_embeds: torch.FloatTensor, + timestep_img: Union[torch.Tensor, float, int], + timestep_text: Union[torch.Tensor, float, int], + data_type: Optional[Union[torch.Tensor, float, int]] = 1, + encoder_hidden_states=None, + cross_attention_kwargs=None, + ): + """ + Args: + latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`): + Latent image representation from the VAE encoder. + image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`): + CLIP-embedded image representation (unsqueezed in the first dimension). + prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`): + CLIP-embedded text representation. + timestep_img (`torch.long` or `float` or `int`): + Current denoising step for the image. + timestep_text (`torch.long` or `float` or `int`): + Current denoising step for the text. + data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`): + Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data, + or `0` otherwise. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + + + Returns: + `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE + image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text + embedding. + """ + batch_size = latent_image_embeds.shape[0] + + # 1. Input + # 1.1. Map inputs to shape (B, N, inner_dim) + vae_hidden_states = self.vae_img_in(latent_image_embeds) + clip_hidden_states = self.clip_img_in(image_embeds) + text_hidden_states = self.text_in(prompt_embeds) + + num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1) + + # 1.2. Encode image timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_img): + timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device) + + timestep_img_token = self.timestep_img_proj(timestep_img) + # t_img_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_img_token = timestep_img_token.to(dtype=self.dtype) + timestep_img_token = self.timestep_img_embed(timestep_img_token) + timestep_img_token = timestep_img_token.unsqueeze(dim=1) + + # 1.3. Encode text timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_text): + timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device) + + timestep_text_token = self.timestep_text_proj(timestep_text) + # t_text_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_text_token = timestep_text_token.to(dtype=self.dtype) + timestep_text_token = self.timestep_text_embed(timestep_text_token) + timestep_text_token = timestep_text_token.unsqueeze(dim=1) + + # 1.4. Concatenate all of the embeddings together. + if self.use_data_type_embedding: + assert data_type is not None, "data_type must be supplied if the model uses a data type embedding" + if not torch.is_tensor(data_type): + data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device) + + data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1) + hidden_states = torch.cat( + [ + timestep_img_token, + timestep_text_token, + data_type_token, + text_hidden_states, + clip_hidden_states, + vae_hidden_states, + ], + dim=1, + ) + else: + hidden_states = torch.cat( + [timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], + dim=1, + ) + + # 1.5. Prepare the positional embeddings and add to hidden states + # Note: I think img_vae should always have the proper shape, so there's no need to interpolate + # the position embeddings. + if self.use_data_type_embedding: + pos_embed = torch.cat( + [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1 + ) + else: + pos_embed = self.pos_embed + hidden_states = hidden_states + pos_embed + hidden_states = self.pos_embed_drop(hidden_states) + + # 2. Blocks + hidden_states = self.transformer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=None, + class_labels=None, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + hidden_states_is_embedding=True, + unpatchify=False, + )[0] + + # 3. Output + # Split out the predicted noise representation. + if self.use_data_type_embedding: + ( + t_img_token_out, + t_text_token_out, + data_type_token_out, + text_out, + img_clip_out, + img_vae_out, + ) = hidden_states.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1) + else: + t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split( + (1, 1, num_text_tokens, 1, num_img_tokens), dim=1 + ) + + img_vae_out = self.vae_img_out(img_vae_out) + + # unpatchify + height = width = int(img_vae_out.shape[1] ** 0.5) + img_vae_out = img_vae_out.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out) + img_vae_out = img_vae_out.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + img_clip_out = self.clip_img_out(img_clip_out) + + text_out = self.text_out(text_out) + + return img_vae_out, img_clip_out, text_out diff --git a/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py new file mode 100644 index 0000000000000000000000000000000000000000..3632d74d1c12d2d0aab4bff2ad812e163ea48ee1 --- /dev/null +++ b/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -0,0 +1,1382 @@ +import inspect +import warnings +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from ...models import AutoencoderKL +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from ...utils.outputs import BaseOutput +from ..pipeline_utils import DiffusionPipeline +from .modeling_text_decoder import UniDiffuserTextDecoder +from .modeling_uvit import UniDiffuserModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# New BaseOutput child class for joint image-text output +@dataclass +class ImageTextPipelineOutput(BaseOutput): + """ + Output class for joint image-text pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + text (`List[str]` or `List[List[str]]`) + List of generated text strings of length `batch_size` or a list of list of strings whose outer list has + length `batch_size`. + """ + + images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + text: Optional[Union[List[str], List[List[str]]]] + + +class UniDiffuserPipeline(DiffusionPipeline): + r""" + Pipeline for a bimodal image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model, which supports + unconditional text and image generation, text-conditioned image generation, image-conditioned text generation, and + joint image-text generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. This + is part of the UniDiffuser image representation, along with the CLIP vision encoding. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text + prompts. + image_encoder ([`CLIPVisionModel`]): + UniDiffuser uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode + images as part of its image representation, along with the VAE latent representation. + image_processor ([`CLIPImageProcessor`]): + CLIP image processor of class + [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor), + used to preprocess the image before CLIP encoding it with `image_encoder`. + clip_tokenizer ([`CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which + is used to tokenizer a prompt before encoding it with `text_encoder`. + text_decoder ([`UniDiffuserTextDecoder`]): + Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser + embedding. + text_tokenizer ([`GPT2Tokenizer`]): + Tokenizer of class + [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which + is used along with the `text_decoder` to decode text for text generation. + unet ([`UniDiffuserModel`]): + UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a + [`Transformer2DModel`] with U-Net-style skip connections between transformer layers. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The + original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModelWithProjection, + image_processor: CLIPImageProcessor, + clip_tokenizer: CLIPTokenizer, + text_decoder: UniDiffuserTextDecoder, + text_tokenizer: GPT2Tokenizer, + unet: UniDiffuserModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim: + raise ValueError( + f"The text encoder hidden size and text decoder prefix inner dim must be the same, but" + f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}" + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_processor=image_processor, + clip_tokenizer=clip_tokenizer, + text_decoder=text_decoder, + text_tokenizer=text_tokenizer, + unet=unet, + scheduler=scheduler, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.num_channels_latents = vae.config.latent_channels + self.text_encoder_seq_len = text_encoder.config.max_position_embeddings + self.text_encoder_hidden_size = text_encoder.config.hidden_size + self.image_encoder_projection_dim = image_encoder.config.projection_dim + self.unet_resolution = unet.config.sample_size + + self.text_intermediate_dim = self.text_encoder_hidden_size + if self.text_decoder.prefix_hidden_dim is not None: + self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim + + self.mode = None + + # TODO: handle safety checking? + self.safety_checker = None + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae, self.image_encoder, self.text_decoder]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents): + r""" + Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set + mode will be used. + """ + prompt_available = (prompt is not None) or (prompt_embeds is not None) + image_available = image is not None + input_available = prompt_available or image_available + + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + full_latents_available = latents is not None + image_latents_available = vae_latents_available and clip_latents_available + all_indv_latents_available = prompt_latents_available and image_latents_available + + if self.mode is not None: + # Preferentially use the mode set by the user + mode = self.mode + elif prompt_available: + mode = "text2img" + elif image_available: + mode = "img2text" + else: + # Neither prompt nor image supplied, infer based on availability of latents + if full_latents_available or all_indv_latents_available: + mode = "joint" + elif prompt_latents_available: + mode = "text" + elif image_latents_available: + mode = "img" + else: + # No inputs or latents available + mode = "joint" + + # Give warnings for ambiguous cases + if self.mode is None and prompt_available and image_available: + logger.warning( + f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," + f" defaulting to mode '{mode}'." + ) + + if self.mode is None and not input_available: + if vae_latents_available != clip_latents_available: + # Exactly one of vae_latents and clip_latents is supplied + logger.warning( + f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none" + f" are expected to be supplied. Defaulting to mode '{mode}'." + ) + elif not prompt_latents_available and not vae_latents_available and not clip_latents_available: + # No inputs or latents supplied + logger.warning( + f"No inputs or latents have been supplied, and mode has not been manually set," + f" defaulting to mode '{mode}'." + ) + + return mode + + # Functions to manually set the mode + def set_text_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") text generation.""" + self.mode = "text" + + def set_image_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") image generation.""" + self.mode = "img" + + def set_text_to_image_mode(self): + r"""Manually set the generation mode to text-conditioned image generation.""" + self.mode = "text2img" + + def set_image_to_text_mode(self): + r"""Manually set the generation mode to image-conditioned text generation.""" + self.mode = "img2text" + + def set_joint_mode(self): + r"""Manually set the generation mode to unconditional joint image-text generation.""" + self.mode = "joint" + + def reset_mode(self): + r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs.""" + self.mode = None + + def _infer_batch_size( + self, + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ): + r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.""" + if num_images_per_prompt is None: + num_images_per_prompt = 1 + if num_prompts_per_image is None: + num_prompts_per_image = 1 + + assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer" + assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer" + + if mode in ["text2img"]: + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + # Either prompt or prompt_embeds must be present for text2img. + batch_size = prompt_embeds.shape[0] + multiplier = num_images_per_prompt + elif mode in ["img2text"]: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + else: + # Image must be available and type either PIL.Image.Image or torch.FloatTensor. + # Not currently supporting something like image_embeds. + batch_size = image.shape[0] + multiplier = num_prompts_per_image + elif mode in ["img"]: + if vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + multiplier = num_images_per_prompt + elif mode in ["text"]: + if prompt_latents is not None: + batch_size = prompt_latents.shape[0] + else: + batch_size = 1 + multiplier = num_prompts_per_image + elif mode in ["joint"]: + if latents is not None: + batch_size = latents.shape[0] + elif prompt_latents is not None: + batch_size = prompt_latents.shape[0] + elif vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + + if num_images_per_prompt == num_prompts_per_image: + multiplier = num_images_per_prompt + else: + multiplier = min(num_images_per_prompt, num_prompts_per_image) + logger.warning( + f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and" + f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using batch size equal to" + f" `min(num_images_per_prompt, num_prompts_per_image) = {batch_size}." + ) + return batch_size, multiplier + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + # self.tokenizer => self.clip_tokenizer + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.clip_tokenizer( + prompt, + padding="max_length", + max_length=self.clip_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.clip_tokenizer.batch_decode( + untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.clip_tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents + # Add num_prompts_per_image argument, sample from autoencoder moment distribution + def encode_image_vae_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + do_classifier_free_guidance, + generator=None, + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + * self.vae.config.scaling_factor + for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + # Scale image_latents by the VAE's scaling factor + image_latents = image_latents * self.vae.config.scaling_factor + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def encode_image_clip_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + generator=None, + ): + # Map image to CLIP embedding. + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + preprocessed_image = self.image_processor.preprocess( + image, + return_tensors="pt", + ) + preprocessed_image = preprocessed_image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list): + image_latents = [ + self.image_encoder(**preprocessed_image[i : i + 1]).image_embeds for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.image_encoder(**preprocessed_image).image_embeds + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + return image_latents + + # Note that the CLIP latents are not decoded for image generation. + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + # Rename: decode_latents -> decode_image_latents + def decode_image_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_text_latents( + self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded prompt. + shape = (batch_size * num_images_per_prompt, seq_len, hidden_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shace (B, L, D) + latents = latents.repeat(num_images_per_prompt, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + # Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument. + def prepare_image_vae_latents( + self, + batch_size, + num_prompts_per_image, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size * num_prompts_per_image, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, C, H, W) + latents = latents.repeat(num_prompts_per_image, 1, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_clip_latents( + self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded image. + shape = (batch_size * num_prompts_per_image, 1, clip_img_dim) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, L, D) + latents = latents.repeat(num_prompts_per_image, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _split(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W) + and (B, 1, clip_img_dim) + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + + img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + return img_vae, img_clip + + def _combine(self, img_vae, img_clip): + r""" + Combines a latent iamge img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1, + clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + return torch.concat([img_vae, img_clip], dim=-1) + + def _split_joint(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim] into (img_vae, + img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is + of shape (B, text_seq_len, text_dim). + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_intermediate_dim + + img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim)) + return img_vae, img_clip, text + + def _combine_joint(self, img_vae, img_clip, text): + r""" + Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img, + clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B, + C * H * W + L_img * clip_img_dim + L_text * text_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + text = torch.reshape(text, (text.shape[0], -1)) + return torch.concat([img_vae, img_clip, text], dim=-1) + + def _get_noise_pred( + self, + mode, + latents, + t, + prompt_embeds, + img_vae, + img_clip, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ): + r""" + Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary. + """ + if mode == "joint": + # Joint text-image generation + img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type + ) + + x_out = self._combine_joint(img_vae_out, img_clip_out, text_out) + + if guidance_scale <= 1.0: + return x_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + _, _, text_out_uncond = self.unet( + img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + img_vae_out_uncond, img_clip_out_uncond, _ = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond) + + return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond + elif mode == "text2img": + # Text-conditioned image generation + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type + ) + + img_out = self._combine(img_vae_out, img_clip_out) + + if guidance_scale <= 1.0: + return img_out + + # Classifier-free guidance + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond) + + return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond + elif mode == "img2text": + # Image-conditioned text generation + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type + ) + + if guidance_scale <= 1.0: + return text_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond + elif mode == "text": + # Unconditional ("marginal") text generation (no CFG) + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return text_out + elif mode == "img": + # Unconditional ("marginal") image generation (no CFG) + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, + img_clip_latents, + prompt_embeds, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out = self._combine(img_vae_out, img_clip_out) + return img_out + + def check_latents_shape(self, latents_name, latents, expected_shape): + latents_shape = latents.shape + expected_num_dims = len(expected_shape) + 1 # expected dimensions plus the batch dimension + expected_shape_str = ", ".join(str(dim) for dim in expected_shape) + if len(latents_shape) != expected_num_dims: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {len(latents_shape)} dimensions." + ) + for i in range(1, expected_num_dims): + if latents_shape[i] != expected_shape[i - 1]: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {latents_shape[i]} != {expected_shape[i - 1]} at dimension {i}." + ) + + def check_inputs( + self, + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + latents=None, + prompt_latents=None, + vae_latents=None, + clip_latents=None, + ): + # Check inputs before running the generative process. + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if mode == "text2img": + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if mode == "img2text": + if image is None: + raise ValueError("`img2text` mode requires an image to be provided.") + + # Check provided latents + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + full_latents_available = latents is not None + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + + if full_latents_available: + individual_latents_available = ( + prompt_latents is not None or vae_latents is not None or clip_latents is not None + ) + if individual_latents_available: + logger.warning( + "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and" + " `clip_latents`. The value of `latents` will override the value of any individually supplied latents." + ) + # Check shape of full latents + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size + latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim + latents_expected_shape = (latents_dim,) + self.check_latents_shape("latents", latents, latents_expected_shape) + + # Check individual latent shapes, if present + if prompt_latents_available: + prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size) + self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape) + + if vae_latents_available: + vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width) + self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape) + + if clip_latents_available: + clip_latents_expected_shape = (1, self.image_encoder_projection_dim) + self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape) + + if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available: + if vae_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:" + f" {vae_latents.shape[0]} != {clip_latents.shape[0]}." + ) + + if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available: + if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch" + f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}" + f" != {clip_latents.shape[0]}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + data_type: Optional[int] = 1, + num_inference_steps: int = 50, + guidance_scale: float = 8.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + num_prompts_per_image: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_latents: Optional[torch.FloatTensor] = None, + vae_latents: Optional[torch.FloatTensor] = None, + clip_latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. Required for text-conditioned image generation (`text2img`) mode. + image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): + `Image`, or tensor representing an image batch. Required for image-conditioned text generation + (`img2text`) mode. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + data_type (`int`, *optional*, defaults to 1): + The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type + embedding; this is added for compatibility with the UniDiffuser-v1 checkpoint. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the original [UniDiffuser + paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of the guidance scale `w'`, + which satisfies `w = w' + 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). Used in text-conditioned image generation (`text2img`) mode. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and + `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. + num_prompts_per_image (`int`, *optional*, defaults to 1): + The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and + `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for joint + image-text generation. Can be used to tweak the same generation with different prompts. If not + provided, a latents tensor will be generated by sampling using the supplied random `generator`. Note + that this is assumed to be a full set of VAE, CLIP, and text latents, if supplied, this will override + the value of `prompt_latents`, `vae_latents`, and `clip_latents`. + prompt_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + vae_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + clip_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. Used in text-conditioned + image generation (`text2img`) mode. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. Used in text-conditioned image generation (`text2img`) mode. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`: + [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of generated texts. + """ + + # 0. Default height and width to unet + height = height or self.unet_resolution * self.vae_scale_factor + width = width or self.unet_resolution * self.vae_scale_factor + + # 1. Check inputs + # Recalculate mode for each call to the pipeline. + mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents) + self.check_inputs( + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + + # 2. Define call parameters + batch_size, multiplier = self._infer_batch_size( + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + device = self._execution_device + reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img" + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + # Note that this differs from the formulation in the unidiffusers paper! + # do_classifier_free_guidance = guidance_scale > 1.0 + + # check if scheduler is in sigmas space + # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt, if available; otherwise prepare text latents + if latents is not None: + # Overwrite individual latents + vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width) + + if mode in ["text2img"]: + # 3.1. Encode input prompt, if available + assert prompt is not None or prompt_embeds is not None + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=multiplier, + do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + else: + # 3.2. Prepare text latent variables, if input not available + prompt_embeds = self.prepare_text_latents( + batch_size=batch_size, + num_images_per_prompt=multiplier, + seq_len=self.text_encoder_seq_len, + hidden_size=self.text_encoder_hidden_size, + dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision + device=device, + generator=generator, + latents=prompt_latents, + ) + + if reduce_text_emb_dim: + prompt_embeds = self.text_decoder.encode(prompt_embeds) + + # 4. Encode image, if available; otherwise prepare image latents + if mode in ["img2text"]: + # 4.1. Encode images, if available + assert image is not None, "`img2text` requires a conditioning image" + # Encode image using VAE + image_vae = preprocess(image) + height, width = image_vae.shape[-2:] + image_vae_latents = self.encode_image_vae_latents( + image=image_vae, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG + generator=generator, + ) + + # Encode image using CLIP + image_clip_latents = self.encode_image_clip_latents( + image=image, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size) + image_clip_latents = image_clip_latents.unsqueeze(1) + else: + # 4.2. Prepare image latent variables, if input not available + # Prepare image VAE latents in latent space + image_vae_latents = self.prepare_image_vae_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=vae_latents, + ) + + # Prepare image CLIP latents + image_clip_latents = self.prepare_image_clip_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + clip_img_dim=self.image_encoder_projection_dim, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=clip_latents, + ) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + # max_timestep = timesteps[0] + max_timestep = self.scheduler.config.num_train_timesteps + + # 6. Prepare latent variables + if mode == "joint": + latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds) + elif mode in ["text2img", "img"]: + latents = self._combine(image_vae_latents, image_clip_latents) + elif mode in ["img2text", "text"]: + latents = prompt_embeds + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}") + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # predict the noise residual + # Also applies classifier-free guidance as described in the UniDiffuser paper + noise_pred = self._get_noise_pred( + mode, + latents, + t, + prompt_embeds, + image_vae_latents, + image_clip_latents, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + gen_image = None + gen_text = None + if mode == "joint": + image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) + + # Map latent VAE image back to pixel space + gen_image = self.decode_image_latents(image_vae_latents) + + # Generate text using the text decoder + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + gen_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + elif mode in ["text2img", "img"]: + image_vae_latents, image_clip_latents = self._split(latents, height, width) + gen_image = self.decode_image_latents(image_vae_latents) + elif mode in ["img2text", "text"]: + text_latents = latents + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + gen_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + + # 10. Convert to PIL + if output_type == "pil" and gen_image is not None: + gen_image = self.numpy_to_pil(gen_image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (gen_image, gen_text) + + return ImageTextPipelineOutput(images=gen_image, text=gen_text) diff --git a/diffusers/pipelines/versatile_diffusion/__init__.py b/diffusers/pipelines/versatile_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abf9dcff59dbc922dcc7063a1e73560679a23696 --- /dev/null +++ b/diffusers/pipelines/versatile_diffusion/__init__.py @@ -0,0 +1,24 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) +else: + from .modeling_text_unet import UNetFlatConditionModel + from .pipeline_versatile_diffusion import VersatileDiffusionPipeline + from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline + from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline + from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ee0ffa87ae55f15d22abd747d2c05adf0a520e9 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c089885c4847a05ce6d911fa0daf7c0d21e739e Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e898886cbc4826e03e33dfcc66967db4eb80a6a Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e479452e3eac4394a8f7c5b6fc0d60f8e81fcca8 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-38.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de794825198274c1b623392c3227eb6861fc1df6 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23c1c4a3932b6b123e00ef2cfa4bbdeec18e2200 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-38.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc2f9ce1bd55c76ff4dd32cc0ed6e5ea4232d1f4 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be2e2d93a44b9fb566f30f1bb8a8ee82d531eb13 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-38.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec606f5452f82490738289d7612ac12e85e0b906 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa6cb7369180dd15121d0ff00972111e789a09a9 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-38.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29cda28d5ce8c8bbd7357cfe688c763eb1b1fcba Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec0ea3eeb194cacb7b41cae359f5ccbd821f4ef1 Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-38.pyc differ diff --git a/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..82628104eba2a30b99487868e6bb6db5f4b7214c --- /dev/null +++ b/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -0,0 +1,1931 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...models.activations import get_activation +from ...models.attention import Attention +from ...models.attention_processor import ( + AttentionProcessor, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, +) +from ...models.dual_transformer_2d import DualTransformer2DModel +from ...models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from ...models.transformer_2d import Transformer2DModel +from ...models.unet_2d_condition import UNet2DConditionOutput +from ...utils import is_torch_version, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockFlat": + return DownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat") + return CrossAttnDownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} is not supported.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockFlat": + return UpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") + return CrossAttnUpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} is not supported.") + + +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or + `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads`" + " because of a naming issue as described in" + " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing" + " `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:" + f" {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:" + f" {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + "Must provide the same number of `only_cross_attention` as `down_block_types`." + f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:" + f" {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:" + f" {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + "Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:" + f" {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:" + f" {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = LinearMultiDim( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlockFlatCrossAttn": + self.mid_block = UNetMidBlockFlatCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": + self.mid_block = UNetMidBlockFlatSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = LinearMultiDim( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNetFlatConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires" + " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires" + " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires" + " the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the" + " keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires" + " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which" + " requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires" + " the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlockFlat + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) + if out_features is None: + out_features = in_features + out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + + in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + + if out_channels is not None: + out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) + + return output_tensor + + +# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6d6b5e7863ebb9b53ba741138b0829eab509888c --- /dev/null +++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -0,0 +1,434 @@ +import inspect +from typing import Callable, List, Optional, Union + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging +from ..pipeline_utils import DiffusionPipeline +from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline +from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline +from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionMegaSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModel + image_encoder: CLIPVisionModel + image_unet: UNet2DConditionModel + text_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPImageProcessor, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModel, + image_unet: UNet2DConditionModel, + text_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + @torch.no_grad() + def image_variation( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.image_variation(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + return VersatileDiffusionImageVariationPipeline(**components)( + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def text_to_image( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) + output = temp_pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + # swap the attention blocks back to the original state + temp_pipeline._swap_unet_attention_blocks() + + return output + + @torch.no_grad() + def dual_guided( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe.dual_guided( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + + expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) + output = temp_pipeline( + prompt=prompt, + image=image, + text_to_image_strength=text_to_image_strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + temp_pipeline._revert_dual_attention() + + return output diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py new file mode 100644 index 0000000000000000000000000000000000000000..5986d66a61e79e90dd7b86884f621f956cd52cba --- /dev/null +++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -0,0 +1,557 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.utils.checkpoint +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModelWithProjection + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPImageProcessor, + text_encoder: CLIPTextModelWithProjection, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if self.text_unet is not None and ( + "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention + ): + # if loading from a universal checkpoint rather than a saved dual-guided pipeline + self._convert_to_dual_attention() + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _convert_to_dual_attention(self): + """ + Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks + from both `image_unet` and `text_unet` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + + image_transformer = self.image_unet.get_submodule(parent_name)[index] + text_transformer = self.text_unet.get_submodule(parent_name)[index] + + config = image_transformer.config + dual_transformer = DualTransformer2DModel( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + num_vector_embeds=config.num_vector_embeds, + activation_fn=config.activation_fn, + num_embeds_ada_norm=config.num_embeds_ada_norm, + ) + dual_transformer.transformers[0] = image_transformer + dual_transformer.transformers[1] = text_transformer + + self.image_unet.get_submodule(parent_name)[index] = dual_transformer + self.image_unet.register_to_config(dual_cross_attention=True) + + def _revert_dual_attention(self): + """ + Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call + this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] + + self.image_unet.register_to_config(dual_cross_attention=False) + + def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = normalize_embeddings(prompt_embeds) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, image, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") + if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): + raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + module.mix_ratio = mix_ratio + + for i, type in enumerate(condition_types): + if type == "text": + module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings + module.transformer_index_for_condition[i] = 1 # use the second (text) transformer + else: + module.condition_lengths[i] = 257 + module.transformer_index_for_condition[i] = 0 # use the first (image) transformer + + @torch.no_grad() + def __call__( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionDualGuidedPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, callback_steps) + + # 2. Define call parameters + prompt = [prompt] if not isinstance(prompt, list) else prompt + image = [image] if not isinstance(image, list) else image + batch_size = len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) + dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) + prompt_types = ("text", "image") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dual_prompt_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Combine the attention blocks of the image and text UNets + self.set_transformer_params(text_to_image_strength, prompt_types) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py new file mode 100644 index 0000000000000000000000000000000000000000..154548df7542576dd00663ffa5d3e3e45d60da0b --- /dev/null +++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -0,0 +1,399 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.utils.checkpoint +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + image_feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + def __init__( + self, + image_feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + image_feature_extractor=image_feature_extractor, + image_encoder=image_encoder, + image_unet=image_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: + prompt = list(prompt) + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images: List[str] + if negative_prompt is None: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, PIL.Image.Image): + uncond_images = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_images = negative_prompt + + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionImageVariationPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + image_embeddings = self._encode_prompt( + image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba5d8451f2eb1f655bbf0d44edaa5b35edb2e87 --- /dev/null +++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -0,0 +1,473 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import torch +import torch.utils.checkpoint +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if self.text_unet is not None: + self._swap_unet_attention_blocks() + + def _swap_unet_attention_blocks(self): + """ + Swap the `Transformer2DModel` blocks between the image and text UNets + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( + self.text_unet.get_submodule(parent_name)[index], + self.image_unet.get_submodule(parent_name)[index], + ) + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = normalize_embeddings(prompt_embeds) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionTextToImagePipeline + >>> import torch + + >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/vq_diffusion/__init__.py b/diffusers/pipelines/vq_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c9f14f000648347fe75a5bec0cb45d08c7d2ff9 --- /dev/null +++ b/diffusers/pipelines/vq_diffusion/__init__.py @@ -0,0 +1,5 @@ +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline diff --git a/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09dde4c595e70e9cb132ed3fb912db4c3bd6df04 Binary files /dev/null and b/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3aa532e7fc4fd0468cbb26aee2414d8a47f03ea Binary files /dev/null and b/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc b/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..485bd7fb224c1729864848edc4748f221a6a14da Binary files /dev/null and b/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc differ diff --git a/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-38.pyc b/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8626f0de557481b0aa64736c366aefd78de8a955 Binary files /dev/null and b/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-38.pyc differ diff --git a/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9147afe127e4b24366249c4a6e058abae9501050 --- /dev/null +++ b/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -0,0 +1,330 @@ +# Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin, Transformer2DModel, VQModel +from ...schedulers import VQDiffusionScheduler +from ...utils import logging +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): + super().__init__() + + self.learnable = learnable + + if self.learnable: + assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + + self.embeddings = torch.nn.Parameter(embeddings) + + +class VQDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using VQ Diffusion + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vqvae ([`VQModel`]): + Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent + representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. VQ Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + transformer ([`Transformer2DModel`]): + Conditional transformer to denoise the encoded image latents. + scheduler ([`VQDiffusionScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + vqvae: VQModel + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + transformer: Transformer2DModel + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings + scheduler: VQDiffusionScheduler + + def __init__( + self, + vqvae: VQModel, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: Transformer2DModel, + scheduler: VQDiffusionScheduler, + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + if self.learned_classifier_free_sampling_embeddings.learnable: + negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1) + else: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # See comment for normalizing text embeddings + negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + num_inference_steps: int = 100, + guidance_scale: float = 5.0, + truncation_rate: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): + Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at + most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above + `truncation_rate` are set to zero. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor` of shape (batch), *optional*): + Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. + Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will + be generated of completely masked latent pixels. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` + is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get the initial completely masked latents unless the user supplied it + + latents_shape = (batch_size, self.transformer.num_latent_pixels) + if latents is None: + mask_class = self.transformer.num_vector_embeds - 1 + latents = torch.full(latents_shape, mask_class).to(self.device) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): + raise ValueError( + "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," + f" {self.transformer.num_vector_embeds - 1} (inclusive)." + ) + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + sample = latents + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the sample if we are doing classifier free guidance + latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample + + # predict the un-noised image + # model_output == `log_p_x_0` + model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample + + if do_classifier_free_guidance: + model_output_uncond, model_output_text = model_output.chunk(2) + model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) + model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) + + model_output = self.truncate(model_output, truncation_rate) + + # remove `log(0)`'s (`-inf`s) + model_output = model_output.clamp(-70) + + # compute the previous noisy sample x_t -> x_t-1 + sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + embedding_channels = self.vqvae.config.vq_embed_dim + embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) + embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) + image = self.vqvae.decode(embeddings, force_not_quantize=True).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor: + """ + Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The + lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero. + """ + sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) + sorted_p_x_0 = torch.exp(sorted_log_p_x_0) + keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate + + # Ensure that at least the largest probability is not zeroed out + all_true = torch.full_like(keep_mask[:, 0:1, :], True) + keep_mask = torch.cat((all_true, keep_mask), dim=1) + keep_mask = keep_mask[:, :-1, :] + + keep_mask = keep_mask.gather(1, indices.argsort(1)) + + rv = log_p_x_0.clone() + + rv[~keep_mask] = -torch.inf # -inf = log(0) + + return rv diff --git a/diffusers/schedulers/README.md b/diffusers/schedulers/README.md new file mode 100644 index 0000000000000000000000000000000000000000..31ad27793e34783faabc222adf98691fb396a0d8 --- /dev/null +++ b/diffusers/schedulers/README.md @@ -0,0 +1,3 @@ +# Schedulers + +For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers/overview). \ No newline at end of file diff --git a/diffusers/schedulers/__init__.py b/diffusers/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a07ce4baed20904b85b577aa3e4e38f6a47e945 --- /dev/null +++ b/diffusers/schedulers/__init__.py @@ -0,0 +1,92 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_scipy_available, + is_torch_available, + is_torchsde_available, +) + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 +else: + from .scheduling_consistency_models import CMStochasticIterativeScheduler + from .scheduling_ddim import DDIMScheduler + from .scheduling_ddim_inverse import DDIMInverseScheduler + from .scheduling_ddim_parallel import DDIMParallelScheduler + from .scheduling_ddpm import DDPMScheduler + from .scheduling_ddpm_parallel import DDPMParallelScheduler + from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler + from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler + from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler + from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler + from .scheduling_euler_discrete import EulerDiscreteScheduler + from .scheduling_heun_discrete import HeunDiscreteScheduler + from .scheduling_ipndm import IPNDMScheduler + from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler + from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler + from .scheduling_karras_ve import KarrasVeScheduler + from .scheduling_pndm import PNDMScheduler + from .scheduling_repaint import RePaintScheduler + from .scheduling_sde_ve import ScoreSdeVeScheduler + from .scheduling_sde_vp import ScoreSdeVpScheduler + from .scheduling_unclip import UnCLIPScheduler + from .scheduling_unipc_multistep import UniPCMultistepScheduler + from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + from .scheduling_vq_diffusion import VQDiffusionScheduler + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_objects import * # noqa F403 +else: + from .scheduling_ddim_flax import FlaxDDIMScheduler + from .scheduling_ddpm_flax import FlaxDDPMScheduler + from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler + from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler + from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler + from .scheduling_pndm_flax import FlaxPNDMScheduler + from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler + from .scheduling_utils_flax import ( + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, + ) + + +try: + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 +else: + from .scheduling_lms_discrete import LMSDiscreteScheduler + +try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 +else: + from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler diff --git a/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc b/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..362762968a25a04396fde2a7a4860bbc20611fe7 Binary files /dev/null and b/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/__init__.cpython-38.pyc b/diffusers/schedulers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c17d8944e5467f173b751019543179fda88048b3 Binary files /dev/null and b/diffusers/schedulers/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b64b74a63e9fa5b0e3b05398329935278c2a803 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39c1a91c269880d0102d079b01120e37c757c2d0 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b975de57edfed7090056442a7e4724cda7684301 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecb5f2378795cdcad48535aafd97f87ce12db7be Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17fec920c13eedf2d7c0328d2e0693333b1db134 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4886057f297b1712501cb5d0ccd2f1e405428d42 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46a27798da4b2df456c8d8d0d015517d87273e7a Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffb0bb1da36378f3765d0ae27c2e2c7591d24750 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e1eead3241e8a93088835b30cd6a2d0630396dd Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98baefe397c397858316e924d379623d10c858b9 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd40929b0787c6de89a12a8d6f066be4f6376a63 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..372138392ebc3a92c60f7c0cfdd6f40f60ebc090 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab24c03522a608a4ef734479ea6bacdcb69992a1 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20bfffc6387c9225771f23b276c711bca2f9462b Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33af0455a39cb19426c580ebc3fedda2f9701dc9 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c5f9230bee39346c5e36b26e803233f14d1147 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff68dbd7dab59d0e06612679586d0c5573f16d1d Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bacdfacb5b5ff3acc288299a64a779845e02e12 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eb6787cacfee292a06081ca6cc0798334927bb1 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd562fd887d9cd0013e2362ff18b0be8246876eb Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c655806c4ba18fd049945ce87e6de3b6dfecb0e4 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db174e15986be6e0e45d0666ea49c053d2f1bf30 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfc122f4f88cf9d633c59de7b80eada8b9406814 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38fcc52b2651d95e4a3bc3948a616eddb75f9019 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..546cf379e7807e9fbb93527853aad76a650c68f9 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d24b1be603ff7f5b29281d7af88ab20ebeb2e65a Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7872ee39c72e643020387a52adb3a73a744e34a3 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e7e53da127d8e0eb101871d1817af6783aa67b4 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40852bd769015ee44bbbc88feb72871bbe77f9be Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..088432df988b8a1ea3f02c00960e7f1a4f0c6174 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21861dc07327a721cdd0e2b592de288e3f73ff1f Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96998ea9b7d978917c0d4d73c24f37c28005242b Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f86170ad577b0e9effb01f1d92e959e40307f21c Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bdd88b6d6b9915c7b7ad66c2a314c7470ab6bd7 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..950f9b500065d717bd4783cb98b1b52d4dcffba1 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09c016bf022cb8dda80ef0d5f12770d04d11784d Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f067a472f53466953c26bf0d6dff2cbf22cd8594 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ac775ea4d4a39e8341e3c085e4c2ff1315a2160 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..563f6ce8a3584cb3c4952f4a6e2d672dd97cbc06 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b2a26de7b3b3e32c60d2288b51b074f9672d12b Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54a6f61a4841dcde2cc317da46dcbf60b9ed126c Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db247291040dba901942c21a4fa31f80aa19ec25 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aabb3c5ec409487e0379eb0fb452f48e96e5ea8d Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c56cf91e2c0521e7506ba243ea2d15bbcf47dac4 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..703c3a8c734fc7c6db66110c8f08b07c4111f448 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e759eb5e508a28c65538856ad8d53e4303b9901 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..161c265132aa504908d5f99641a2551223fed859 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..095240969bda27cf9fd8c904d513397a3cb2b7e3 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e556205119a8b1ddb53249243d9eeb18483b3941 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7266bc74eca09ddba385bfc43382be45d8e5128a Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f125b214f79a89ecf89cc177a9cd4d3d2c82a2a2 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc differ diff --git a/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8ed6a5124fddce0c1e78705e8663c9d06998f45 Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-38.pyc differ diff --git a/diffusers/schedulers/scheduling_consistency_models.py b/diffusers/schedulers/scheduling_consistency_models.py new file mode 100644 index 0000000000000000000000000000000000000000..fb296054d65b804af281dc99d940c8f0ba50e01b --- /dev/null +++ b/diffusers/schedulers/scheduling_consistency_models.py @@ -0,0 +1,380 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class CMStochasticIterativeSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): + """ + Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the + paper [1]. + + [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" + https://arxiv.org/pdf/2303.01469 [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based + Generative Models." https://arxiv.org/abs/2206.00364 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + sigma_min (`float`): + Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation. + sigma_max (`float`): + Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation. + sigma_data (`float`): + The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the + original implementation, which is also the original value suggested in the EDM paper. + s_noise (`float`): + The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, + 1.011]. This was set to 1.0 in the original implementation. + rho (`float`): + The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was + set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper. + clip_denoised (`bool`): + Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`. + timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*): + Optionally, an explicit timestep schedule can be specified. The timesteps are expected to be in increasing + order. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 40, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + sigma_data: float = 0.5, + s_noise: float = 1.0, + rho: float = 7.0, + clip_denoised: bool = True, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + ramp = np.linspace(0, 1, num_train_timesteps) + sigmas = self._convert_to_karras(ramp) + timesteps = self.sigma_to_t(sigmas) + + # setable values + self.num_inference_steps = None + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps) + self.custom_timesteps = False + self.is_scale_input_called = False + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + return indices.item() + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM model. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + Returns: + `torch.FloatTensor`: scaled input sample + """ + # Get sigma corresponding to timestep + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_idx = self.index_for_timestep(timestep) + sigma = self.sigmas[step_idx] + + sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + + self.is_scale_input_called = True + return sample + + def sigma_to_t(self, sigmas: Union[float, np.ndarray]): + """ + Gets scaled timesteps from the Karras sigmas, for input to the consistency model. + + Args: + sigmas (`float` or `np.ndarray`): single Karras sigma or array of Karras sigmas + Returns: + `float` or `np.ndarray`: scaled input timestep or scaled input timestep array + """ + if not isinstance(sigmas, np.ndarray): + sigmas = np.array(sigmas, dtype=np.float64) + + timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44) + + return timesteps + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, optional): + custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps` + must be `None`. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.") + + # Follow DDPMScheduler custom timesteps logic + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.custom_timesteps = False + + # Map timesteps to Karras sigmas directly for multistep sampling + # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 + num_train_timesteps = self.config.num_train_timesteps + ramp = timesteps[::-1].copy() + ramp = ramp / (num_train_timesteps - 1) + sigmas = self._convert_to_karras(ramp) + timesteps = self.sigma_to_t(sigmas) + + sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + # Modified _convert_to_karras implementation that takes in ramp as argument + def _convert_to_karras(self, ramp): + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = self.config.sigma_min + sigma_max: float = self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def get_scalings(self, sigma): + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + def get_scalings_for_boundary_condition(self, sigma): + """ + Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper. + This enforces the consistency model boundary condition. + + Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min. + + Args: + sigma (`torch.FloatTensor`): + The current sigma in the Karras sigma schedule. + Returns: + `tuple`: + A two-element tuple where c_skip (which weights the current sample) is the first element and c_out + (which weights the consistency model output) is the second element. + """ + sigma_min = self.config.sigma_min + sigma_data = self.config.sigma_data + + c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) + c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator (`torch.Generator`, *optional*): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + f" `{self.__class__}.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + sigma_min = self.config.sigma_min + sigma_max = self.config.sigma_max + + step_index = self.index_for_timestep(timestep) + + # sigma_next corresponds to next_t in original implementation + sigma = self.sigmas[step_index] + if step_index + 1 < self.config.num_train_timesteps: + sigma_next = self.sigmas[step_index + 1] + else: + # Set sigma_next to sigma_min + sigma_next = self.sigmas[-1] + + # Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition(sigma) + + # 1. Denoise model output using boundary conditions + denoised = c_out * model_output + c_skip * sample + if self.config.clip_denoised: + denoised = denoised.clamp(-1, 1) + + # 2. Sample z ~ N(0, s_noise^2 * I) + # Noise is not used for onestep sampling. + if len(self.timesteps) > 1: + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + else: + noise = torch.zeros_like(model_output) + z = noise * self.config.s_noise + + sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) + + # 3. Return noisy sample + # tau = sigma_hat, eps = sigma_min + prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5 + + if not return_dict: + return (prev_sample,) + + return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_ddim.py b/diffusers/schedulers/scheduling_ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..a93255ca600ef34da1b6c1691c4c5e9f7f86c2ed --- /dev/null +++ b/diffusers/schedulers/scheduling_ddim.py @@ -0,0 +1,515 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDIMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_ddim_flax.py b/diffusers/schedulers/scheduling_ddim_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..db248c33077bf502e31cb2ab97141744b828b514 --- /dev/null +++ b/diffusers/schedulers/scheduling_ddim_flax.py @@ -0,0 +1,305 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, + get_velocity_common, +) + + +@flax.struct.dataclass +class DDIMSchedulerState: + common: CommonSchedulerState + final_alpha_cumprod: jnp.ndarray + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create( + cls, + common: CommonSchedulerState, + final_alpha_cumprod: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + +@dataclass +class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput): + state: DDIMSchedulerState + + +class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + final_alpha_cumprod = ( + jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] + ) + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return DDIMSchedulerState.create( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + def scale_model_input( + self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def set_timesteps( + self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> DDIMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DDIMSchedulerState`): + the `FlaxDDIMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # rounding to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + ) + + def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep): + alpha_prod_t = state.common.alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where( + prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + state: DDIMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + eta: float = 0.0, + return_dict: bool = True, + ) -> Union[FlaxDDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class + + Returns: + [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + + alphas_cumprod = state.common.alphas_cumprod + final_alpha_cumprod = state.final_alpha_cumprod + + # 2. compute alphas, betas + alpha_prod_t = alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(state, timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if not return_dict: + return (prev_sample, state) + + return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + state: DDIMSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def get_velocity( + self, + state: DDIMSchedulerState, + sample: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return get_velocity_common(state.common, sample, noise, timesteps) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_ddim_inverse.py b/diffusers/schedulers/scheduling_ddim_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..9db83eb992022c80b4fd1cc3769eedca219c96f4 --- /dev/null +++ b/diffusers/schedulers/scheduling_ddim_inverse.py @@ -0,0 +1,349 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, deprecate + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): + """ + DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`]. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_zero (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `0`, + otherwise it uses the value of alpha at step `num_train_timesteps - 1`. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_zero=False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha + product. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + order = 1 + ignore_for_config = ["kwargs"] + _deprecated_kwargs = ["set_alpha_to_zero"] + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + **kwargs, + ): + if kwargs.get("set_alpha_to_zero", None) is not None: + deprecation_message = ( + "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead." + ) + deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False) + set_alpha_to_one = kwargs["set_alpha_to_zero"] + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in inverted ddim, we are looking into the next alphas_cumprod + # For the initial step, there is no current alphas_cumprod, and the index is out of bounds + # `set_alpha_to_one` decides whether we set this parameter simply to one + # in this case, self.step() just output the predicted noise + # or whether we use the initial alpha used in training the diffusion model. + self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) + + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + # Roll timesteps array by one to reflect reversed origin and destination semantics for each step + timesteps = np.roll(timesteps, 1) + timesteps[0] = int(timesteps[1] - step_ratio) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + # 1. get previous step value (=t+1) + prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + # change original implementation to exactly match noise levels for analogous forward process + alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if not return_dict: + return (prev_sample, pred_original_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_ddim_parallel.py b/diffusers/schedulers/scheduling_ddim_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..db3ea0e1cca55f88d0a81d0311158929516cb038 --- /dev/null +++ b/diffusers/schedulers/scheduling_ddim_parallel.py @@ -0,0 +1,642 @@ +# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput +class DDIMParallelSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + _is_ode_scheduler = True + + @register_to_config + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.__init__ + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_variance(self, timestep, prev_timestep=None): + if prev_timestep is None: + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def _batch_get_variance(self, t, prev_t): + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] + alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[DDIMParallelSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def batch_step_no_noise( + self, + model_output: torch.FloatTensor, + timesteps: List[int], + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + ) -> torch.FloatTensor: + """ + Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once. + Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise + is pre-sampled by the pipeline. + + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timesteps (`List[int]`): + current discrete timesteps in the diffusion chain. This is now a list of integers. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + + Returns: + `torch.FloatTensor`: sample tensor at previous timestep. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + assert eta == 0.0 + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + t = timesteps + prev_t = t - self.config.num_train_timesteps // self.num_inference_steps + + t = t.view(-1, *([1] * (model_output.ndim - 1))) + prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) + + # 1. compute alphas, betas + self.alphas_cumprod = self.alphas_cumprod.to(model_output.device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(model_output.device) + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] + alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._batch_get_variance(t, prev_t).to(model_output.device).view(*alpha_prod_t_prev.shape) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + return prev_sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_ddpm.py b/diffusers/schedulers/scheduling_ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b7d7aaa9c22a3a768d1aed131794e810400936 --- /dev/null +++ b/diffusers/schedulers/scheduling_ddpm.py @@ -0,0 +1,513 @@ +# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +class DDPMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DDPMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.custom_timesteps = False + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.variance_type = variance_type + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`Optional[int]`): + the number of diffusion steps used when generating samples with a pre-trained model. If passed, then + `timesteps` must be `None`. + device (`str` or `torch.device`, optional): + the device to which the timesteps are moved to. + custom_timesteps (`List[int]`, optional): + custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps` + must be `None`. + + """ + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + self.custom_timesteps = False + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t, predicted_variance=None, variance_type=None): + prev_t = self.previous_timestep(t) + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t + + # we always take the log of variance, so clamp it to ensure it's not 0 + variance = torch.clamp(variance, min=1e-20) + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = variance + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = torch.log(variance) + variance = torch.exp(0.5 * variance) + elif variance_type == "fixed_large": + variance = current_beta_t + elif variance_type == "fixed_large_log": + # Glide max_log + variance = torch.log(current_beta_t) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = torch.log(variance) + max_log = torch.log(current_beta_t) + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[DDPMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + prev_t = self.previous_timestep(t) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + device = model_output.device + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise + elif self.variance_type == "learned_range": + variance = self._get_variance(t, predicted_variance=predicted_variance) + variance = torch.exp(0.5 * variance) * variance_noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + def previous_timestep(self, timestep): + if self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t diff --git a/diffusers/schedulers/scheduling_ddpm_flax.py b/diffusers/schedulers/scheduling_ddpm_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..529d2bd03a75403e298ec7a30808689a48cf5301 --- /dev/null +++ b/diffusers/schedulers/scheduling_ddpm_flax.py @@ -0,0 +1,299 @@ +# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, + get_velocity_common, +) + + +@flax.struct.dataclass +class DDPMSchedulerState: + common: CommonSchedulerState + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): + return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) + + +@dataclass +class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput): + state: DDPMSchedulerState + + +class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return DDPMSchedulerState.create( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + def scale_model_input( + self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def set_timesteps( + self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> DDPMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DDIMSchedulerState`): + the `FlaxDDPMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # rounding to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + ) + + def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None): + alpha_prod_t = state.common.alphas_cumprod[t] + alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t] + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = jnp.clip(variance, a_min=1e-20) + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = jnp.log(jnp.clip(variance, a_min=1e-20)) + elif variance_type == "fixed_large": + variance = state.common.betas[t] + elif variance_type == "fixed_large_log": + # Glide max_log + variance = jnp.log(state.common.betas[t]) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = variance + max_log = state.common.betas[t] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + state: DDPMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: Optional[jax.random.KeyArray] = None, + return_dict: bool = True, + ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + key (`jax.random.KeyArray`): a PRNG key. + return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class + + Returns: + [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + if key is None: + key = jax.random.PRNGKey(0) + + if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = state.common.alphas_cumprod[t] + alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " + " for the FlaxDDPMScheduler." + ) + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = jnp.clip(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t + current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + def random_variance(): + split_key = jax.random.split(key, num=1) + noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) + return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise + + variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype)) + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample, state) + + return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state) + + def add_noise( + self, + state: DDPMSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def get_velocity( + self, + state: DDPMSchedulerState, + sample: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return get_velocity_common(state.common, sample, noise, timesteps) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_ddpm_parallel.py b/diffusers/schedulers/scheduling_ddpm_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..a92e175877d24057e49bf405e88185fd4297e6d2 --- /dev/null +++ b/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -0,0 +1,604 @@ +# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput +class DDPMParallelSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + _is_ode_scheduler = False + + @register_to_config + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.__init__ + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.custom_timesteps = False + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.variance_type = variance_type + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`Optional[int]`): + the number of diffusion steps used when generating samples with a pre-trained model. If passed, then + `timesteps` must be `None`. + device (`str` or `torch.device`, optional): + the device to which the timesteps are moved to. + custom_timesteps (`List[int]`, optional): + custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps` + must be `None`. + + """ + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + self.custom_timesteps = False + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance + def _get_variance(self, t, predicted_variance=None, variance_type=None): + prev_t = self.previous_timestep(t) + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t + + # we always take the log of variance, so clamp it to ensure it's not 0 + variance = torch.clamp(variance, min=1e-20) + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = variance + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = torch.log(variance) + variance = torch.exp(0.5 * variance) + elif variance_type == "fixed_large": + variance = current_beta_t + elif variance_type == "fixed_large_log": + # Glide max_log + variance = torch.log(current_beta_t) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = torch.log(variance) + max_log = torch.log(current_beta_t) + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[DDPMParallelSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + prev_t = self.previous_timestep(t) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + device = model_output.device + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise + elif self.variance_type == "learned_range": + variance = self._get_variance(t, predicted_variance=predicted_variance) + variance = torch.exp(0.5 * variance) * variance_noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def batch_step_no_noise( + self, + model_output: torch.FloatTensor, + timesteps: List[int], + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once. + Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise + is pre-sampled by the pipeline. + + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timesteps (`List[int]`): + current discrete timesteps in the diffusion chain. This is now a list of integers. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: sample tensor at previous timestep. + """ + t = timesteps + num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + prev_t = t - self.config.num_train_timesteps // num_inference_steps + + t = t.view(-1, *([1] * (model_output.ndim - 1))) + prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + pass + + # 1. compute alphas, betas + self.alphas_cumprod = self.alphas_cumprod.to(model_output.device) + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] + alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMParallelScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + return pred_prev_sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep + def previous_timestep(self, timestep): + if self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t diff --git a/diffusers/schedulers/scheduling_deis_multistep.py b/diffusers/schedulers/scheduling_deis_multistep.py new file mode 100644 index 0000000000000000000000000000000000000000..36947294922b6cc0ecdc5bf7dc9c0772a056d03a --- /dev/null +++ b/diffusers/schedulers/scheduling_deis_multistep.py @@ -0,0 +1,568 @@ +# Copyright 2023 FLAIR Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + DEIS (https://arxiv.org/abs/2204.13902) is a fast high order solver for diffusion ODEs. We slightly modify the + polynomial fitting formula in log-rho space instead of the original linear t space in DEIS paper. The modification + enjoys closed-form coefficients for exponential multistep update instead of replying on the numerical solver. More + variants of DEIS can be found in https://github.com/qsh-zh/deis. + + Currently, we support the log-rho multistep DEIS. We recommend to use `solver_order=2 / 3` while `solver_order=1` + reduces to DDIM. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set `thresholding=True` to use the dynamic thresholding. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DEIS; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and + `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, + or `v-prediction`. + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` + algorithm_type (`str`, default `deis`): + the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in + the future + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "deis", + solver_type: str = "logrho", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DEIS + if algorithm_type not in ["deis"]: + if algorithm_type in ["dpmsolver", "dpmsolver++"]: + self.register_to_config(algorithm_type="deis") + else: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + if solver_type not in ["logrho"]: + if solver_type in ["midpoint", "heun", "bh1", "bh2"]: + self.register_to_config(solver_type="logrho") + else: + raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type that the algorithm DEIS needs. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + if self.config.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DEISMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + if self.config.algorithm_type == "deis": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + return (sample - alpha_t * x0_pred) / sigma_t + else: + raise NotImplementedError("only support log-rho multistep deis now") + + def deis_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the first-order DEIS (equivalent to DDIM). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.config.algorithm_type == "deis": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + else: + raise NotImplementedError("only support log-rho multistep deis now") + return x_t + + def multistep_deis_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DEIS. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1] + sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1] + + rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 + + if self.config.algorithm_type == "deis": + + def ind_fn(t, b, c): + # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}] + return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c)) + + coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1) + coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0) + + x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1) + return x_t + else: + raise NotImplementedError("only support log-rho multistep deis now") + + def multistep_deis_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DEIS. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2] + sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2] + rho_t, rho_s0, rho_s1, rho_s2 = ( + sigma_t / alpha_t, + sigma_s0 / alpha_s0, + sigma_s1 / alpha_s1, + simga_s2 / alpha_s2, + ) + + if self.config.algorithm_type == "deis": + + def ind_fn(t, b, c, d): + # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}] + numerator = t * ( + np.log(c) * (np.log(d) - np.log(t) + 1) + - np.log(d) * np.log(t) + + np.log(d) + + np.log(t) ** 2 + - 2 * np.log(t) + + 2 + ) + denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d)) + return numerator / denominator + + coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2) + coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0) + coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1) + + x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2) + + return x_t + else: + raise NotImplementedError("only support log-rho multistep deis now") + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep DEIS. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + lower_order_final = ( + (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + lower_order_second = ( + (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + timestep_list = [self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_deis_second_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + else: + timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_deis_third_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/diffusers/schedulers/scheduling_dpmsolver_multistep.py new file mode 100644 index 0000000000000000000000000000000000000000..d7516fa601e17cdd5661039c181804d687a66f0e --- /dev/null +++ b/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -0,0 +1,749 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse + diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the + second-order `sde-dpmsolver++`. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or + `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and + the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep DPM-Solver. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + lower_order_final = ( + (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + lower_order_second = ( + (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, timestep, prev_timestep, sample, noise=noise + ) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + timestep_list = [self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, timestep_list, prev_timestep, sample, noise=noise + ) + else: + timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4ee67a7f5dbf8384eaedc0ede322284a413edd --- /dev/null +++ b/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -0,0 +1,622 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, +) + + +@flax.struct.dataclass +class DPMSolverMultistepSchedulerState: + common: CommonSchedulerState + alpha_t: jnp.ndarray + sigma_t: jnp.ndarray + lambda_t: jnp.ndarray + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + # running values + model_outputs: Optional[jnp.ndarray] = None + lower_order_nums: Optional[jnp.int32] = None + prev_timestep: Optional[jnp.int32] = None + cur_sample: Optional[jnp.ndarray] = None + + @classmethod + def create( + cls, + common: CommonSchedulerState, + alpha_t: jnp.ndarray, + sigma_t: jnp.ndarray, + lambda_t: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + +@dataclass +class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput): + state: DPMSolverMultistepSchedulerState + + +class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, + or `v-prediction`. + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the + algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in + https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided + sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # Currently we only support VP-type noise schedule + alpha_t = jnp.sqrt(common.alphas_cumprod) + sigma_t = jnp.sqrt(1 - common.alphas_cumprod) + lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) + + # settings for DPM-Solver + if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]: + raise NotImplementedError(f"{self.config.algorithm_type} does is not implemented for {self.__class__}") + if self.config.solver_type not in ["midpoint", "heun"]: + raise NotImplementedError(f"{self.config.solver_type} does is not implemented for {self.__class__}") + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return DPMSolverMultistepSchedulerState.create( + common=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + def set_timesteps( + self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple + ) -> DPMSolverMultistepSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + the shape of the samples to be generated. + """ + + timesteps = ( + jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .astype(jnp.int32) + ) + + # initial running values + + model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) + lower_order_nums = jnp.int32(0) + prev_timestep = jnp.int32(-1) + cur_sample = jnp.zeros(shape, dtype=self.dtype) + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + model_outputs=model_outputs, + lower_order_nums=lower_order_nums, + prev_timestep=prev_timestep, + cur_sample=cur_sample, + ) + + def convert_model_output( + self, + state: DPMSolverMultistepSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the converted model output. + """ + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.prediction_type == "epsilon": + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + dynamic_max_val = jnp.percentile( + jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) + ) + dynamic_max_val = jnp.maximum( + dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) + ) + x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val + return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." + ) + + def dpm_solver_first_order_update( + self, + state: DPMSolverMultistepSchedulerState, + model_output: jnp.ndarray, + timestep: int, + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0 = prev_timestep, timestep + m0 = model_output + lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0] + alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0 + return x_t + + def multistep_dpm_solver_second_order_update( + self, + state: DPMSolverMultistepSchedulerState, + model_output_list: jnp.ndarray, + timestep_list: List[int], + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[jnp.ndarray]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] + alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + state: DPMSolverMultistepSchedulerState, + model_output_list: jnp.ndarray, + timestep_list: List[int], + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[jnp.ndarray]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + state.lambda_t[t], + state.lambda_t[s0], + state.lambda_t[s1], + state.lambda_t[s2], + ) + alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + state: DPMSolverMultistepSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process + from the learned model outputs (most often the predicted noise). + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class + + Returns: + [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1]) + + model_output = self.convert_model_output(state, model_output, timestep, sample) + + model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0) + model_outputs_new = model_outputs_new.at[-1].set(model_output) + state = state.replace( + model_outputs=model_outputs_new, + prev_timestep=prev_timestep, + cur_sample=sample, + ) + + def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + return self.dpm_solver_first_order_update( + state, + state.model_outputs[-1], + state.timesteps[step_index], + state.prev_timestep, + state.cur_sample, + ) + + def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]]) + return self.multistep_dpm_solver_second_order_update( + state, + state.model_outputs, + timestep_list, + state.prev_timestep, + state.cur_sample, + ) + + def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + timestep_list = jnp.array( + [ + state.timesteps[step_index - 2], + state.timesteps[step_index - 1], + state.timesteps[step_index], + ] + ) + return self.multistep_dpm_solver_third_order_update( + state, + state.model_outputs, + timestep_list, + state.prev_timestep, + state.cur_sample, + ) + + step_2_output = step_2(state) + step_3_output = step_3(state) + + if self.config.solver_order == 2: + return step_2_output + elif self.config.lower_order_final and len(state.timesteps) < 15: + return jax.lax.select( + state.lower_order_nums < 2, + step_2_output, + jax.lax.select( + step_index == len(state.timesteps) - 2, + step_2_output, + step_3_output, + ), + ) + else: + return jax.lax.select( + state.lower_order_nums < 2, + step_2_output, + step_3_output, + ) + + step_1_output = step_1(state) + step_23_output = step_23(state) + + if self.config.solver_order == 1: + prev_sample = step_1_output + + elif self.config.lower_order_final and len(state.timesteps) < 15: + prev_sample = jax.lax.select( + state.lower_order_nums < 1, + step_1_output, + jax.lax.select( + step_index == len(state.timesteps) - 1, + step_1_output, + step_23_output, + ), + ) + + else: + prev_sample = jax.lax.select( + state.lower_order_nums < 1, + step_1_output, + step_23_output, + ) + + state = state.replace( + lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order), + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) + + def scale_model_input( + self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def add_noise( + self, + state: DPMSolverMultistepSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c98a551d665a05d4cbab8ccbdef6785fa2ed09 --- /dev/null +++ b/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -0,0 +1,707 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): + """ + DPMSolverMultistepInverseScheduler is the reverse scheduler of [`DPMSolverMultistepScheduler`]. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or + `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and + the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self.use_karras_sigmas = use_karras_sigmas + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item() + self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', " + "'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = timesteps.copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif "sde" in self.config.algorithm_type: + raise NotImplementedError( + f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif "sde" in self.config.algorithm_type: + raise NotImplementedError( + f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep DPM-Solver. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + prev_timestep = ( + self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + ) + lower_order_final = ( + (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + lower_order_second = ( + (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, timestep, prev_timestep, sample, noise=noise + ) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + timestep_list = [self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, timestep_list, prev_timestep, sample, noise=noise + ) + else: + timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_dpmsolver_sde.py b/diffusers/schedulers/scheduling_dpmsolver_sde.py new file mode 100644 index 0000000000000000000000000000000000000000..a31e97b6965169823634afe8984866a9f7d03ba3 --- /dev/null +++ b/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -0,0 +1,509 @@ +# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torchsde + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get("w0", torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2**63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each + with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): + """ + Implements Stochastic Sampler (Algorithm 2) from Karras et al. (2022). Based on the original k-diffusion + implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/41b4cb6df0506694a7776af31349acf082bf6091/k_diffusion/sampling.py#L543 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + noise_sampler_seed (`int`, *optional*, defaults to `None`): + The random seed to use for the noise sampler. If `None`, a random seed will be generated. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + self.use_karras_sigmas = use_karras_sigmas + self.noise_sampler = None + self.noise_sampler_seed = noise_sampler_seed + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 + else: + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + pos = self._index_counter[timestep_int] + + return indices[pos].item() + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + ) -> torch.FloatTensor: + """ + Args: + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + step_index = self.index_for_timestep(timestep) + + sigma = self.sigmas[step_index] + sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma + sample = sample / ((sigma_input**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + timesteps = torch.from_numpy(timesteps) + second_order_timesteps = torch.from_numpy(second_order_timesteps) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + timesteps[1::2] = second_order_timesteps + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(device, dtype=torch.float32) + else: + self.timesteps = timesteps.to(device=device) + + # empty first order variables + self.sample = None + self.mid_point_sigma = None + + # for exp beta schedules, such as the one for `pipeline_shap_e.py` + # we need an index counter + self._index_counter = defaultdict(int) + + def _second_order_timesteps(self, sigmas, log_sigmas): + def sigma_fn(_t): + return np.exp(-_t) + + def t_fn(_sigma): + return -np.log(_sigma) + + midpoint_ratio = 0.5 + t = t_fn(sigmas) + delta_time = np.diff(t) + t_proposed = t[:-1] + delta_time * midpoint_ratio + sig_proposed = sigma_fn(t_proposed) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed]) + return timesteps + + # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + @property + def state_in_first_order(self): + return self.sample is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + s_noise: float = 1.0, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (Union[torch.FloatTensor, np.ndarray]): Direct output from learned diffusion model. + timestep (Union[float, torch.FloatTensor]): Current discrete timestep in the diffusion chain. + sample (Union[torch.FloatTensor, np.ndarray]): Current instance of sample being created by diffusion process. + return_dict (bool, optional): Option for returning tuple rather than SchedulerOutput class. Defaults to True. + s_noise (float, optional): Scaling factor for the noise added to the sample. Defaults to 1.0. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + step_index = self.index_for_timestep(timestep) + + # advance index counter by 1 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + self._index_counter[timestep_int] += 1 + + # Create a noise sampler if it hasn't been created yet + if self.noise_sampler is None: + min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() + self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed) + + # Define functions to compute sigma and t from each other + def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor: + return _t.neg().exp() + + def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: + return _sigma.log().neg() + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_next = self.sigmas[step_index + 1] + else: + # 2nd order + sigma = self.sigmas[step_index - 1] + sigma_next = self.sigmas[step_index] + + # Set the midpoint and step size for the current step + midpoint_ratio = 0.5 + t, t_next = t_fn(sigma), t_fn(sigma_next) + delta_time = t_next - t + t_proposed = t + delta_time * midpoint_ratio + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if sigma_next == 0: + derivative = (sample - pred_original_sample) / sigma + dt = sigma_next - sigma + prev_sample = sample + derivative * dt + else: + if self.state_in_first_order: + t_next = t_proposed + else: + sample = self.sample + + sigma_from = sigma_fn(t) + sigma_to = sigma_fn(t_next) + sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + ancestral_t = t_fn(sigma_down) + prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( + t - ancestral_t + ).expm1() * pred_original_sample + prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up + + if self.state_in_first_order: + # store for 2nd order step + self.sample = sample + self.mid_point_sigma = sigma_fn(t_next) + else: + # free for "first order mode" + self.sample = None + self.mid_point_sigma = None + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/diffusers/schedulers/scheduling_dpmsolver_singlestep.py new file mode 100644 index 0000000000000000000000000000000000000000..93975a27fc6e3899c009b5576ed74753ea62abbb --- /dev/null +++ b/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -0,0 +1,737 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): + """ + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the singlestep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, + or `v-prediction`. + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the + algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in + https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided + sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable + this to use up all the function evaluations. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. + + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.sample = None + self.order_list = self.get_order_list(num_train_timesteps) + + def get_order_list(self, num_inference_steps: int) -> List[int]: + """ + Computes the solver order at each time step. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + steps = num_inference_steps + order = self.config.solver_order + if self.config.lower_order_final: + if order == 3: + if steps % 3 == 0: + orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1] + elif steps % 3 == 1: + orders = [1, 2, 3] * (steps // 3) + [1] + else: + orders = [1, 2, 3] * (steps // 3) + [1, 2] + elif order == 2: + if steps % 2 == 0: + orders = [1, 2] * (steps // 2) + else: + orders = [1, 2] * (steps // 2) + [1] + elif order == 1: + orders = [1] * steps + else: + if order == 3: + orders = [1, 2, 3] * (steps // 3) + elif order == 2: + orders = [1, 2] * (steps // 2) + elif order == 1: + orders = [1] * steps + return orders + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + + self.timesteps = torch.from_numpy(timesteps).to(device) + self.model_outputs = [None] * self.config.solver_order + self.sample = None + + if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: + logger.warn( + "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`." + ) + self.register_to_config(lower_order_final=True) + + self.order_list = self.get_order_list(num_inference_steps) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] + return model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverSinglestepScheduler." + ) + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + return x_t + + def singlestep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the second-order singlestep DPM-Solver. + + It computes the solution at time `prev_timestep` from the time `timestep_list[-2]`. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1] + sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1] + h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m1, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s1) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s1) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s1) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + return x_t + + def singlestep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the third-order singlestep DPM-Solver. + + It computes the solution at time `prev_timestep` from the time `timestep_list[-3]`. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) + alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2] + sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2] + h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m2 + D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) + D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s2) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s2) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s2) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def singlestep_dpm_solver_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + """ + One step for the singlestep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + order (`int`): + the solver order at this step. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + if order == 1: + return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) + elif order == 2: + return self.singlestep_dpm_solver_second_order_update( + model_output_list, timestep_list, prev_timestep, sample + ) + elif order == 3: + return self.singlestep_dpm_solver_third_order_update( + model_output_list, timestep_list, prev_timestep, sample + ) + else: + raise ValueError(f"Order must be 1, 2, 3, got {order}") + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the singlestep DPM-Solver. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + model_output = self.convert_model_output(model_output, timestep, sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + order = self.order_list[step_index] + + # For img2img denoising might start with order>1 which is not possible + # In this case make sure that the first two steps are both order=1 + while self.model_outputs[-order] is None: + order -= 1 + + # For single-step solvers, we use the initial value at each time with order = 1. + if order == 1: + self.sample = sample + + timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] + prev_sample = self.singlestep_dpm_solver_update( + self.model_outputs, timestep_list, prev_timestep, self.sample, order + ) + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/diffusers/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..065f657032e6ef21bd022f938a3b1e7ada334436 --- /dev/null +++ b/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -0,0 +1,358 @@ +# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete +class EulerAncestralDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise + a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + device = model_output.device + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) + + prev_sample = prev_sample + noise * sigma_up + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_euler_discrete.py b/diffusers/schedulers/scheduling_euler_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..cb126d4b953cd28e23d048c4f1e2cf8ed90cdac0 --- /dev/null +++ b/diffusers/schedulers/scheduling_euler_discrete.py @@ -0,0 +1,432 @@ +# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete +class EulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original + k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `"epsilon"`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + interpolation_type (`str`, default `"linear"`, optional): + interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of + [`"linear"`, `"log_linear"`]. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + interpolation_type: str = "linear", + use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + self.use_karras_sigmas = use_karras_sigmas + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + sample = sample / ((sigma**2 + 1) ** 0.5) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.interpolation_type == "linear": + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + elif self.config.interpolation_type == "log_linear": + sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp() + else: + raise ValueError( + f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" + " 'linear' or 'log_linear'" + ) + + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma_hat * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + + dt = self.sigmas[step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_heun_discrete.py b/diffusers/schedulers/scheduling_heun_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..5f694fd60fc9f7f596f0d28d19cc231a26712fd1 --- /dev/null +++ b/diffusers/schedulers/scheduling_heun_discrete.py @@ -0,0 +1,426 @@ +# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original + k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf). + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + use_karras_sigmas: Optional[bool] = False, + clip_sample: Optional[bool] = False, + clip_sample_range: float = 1.0, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine") + elif beta_schedule == "exp": + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp") + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + self.use_karras_sigmas = use_karras_sigmas + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 + else: + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + pos = self._index_counter[timestep_int] + + return indices[pos].item() + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + ) -> torch.FloatTensor: + """ + Args: + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + step_index = self.index_for_timestep(timestep) + + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + timesteps = torch.from_numpy(timesteps) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(device, dtype=torch.float32) + else: + self.timesteps = timesteps.to(device=device) + + # empty dt and derivative + self.prev_derivative = None + self.dt = None + + # for exp beta schedules, such as the one for `pipeline_shap_e.py` + # we need an index counter + self._index_counter = defaultdict(int) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + @property + def state_in_first_order(self): + return self.dt is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep + (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + step_index = self.index_for_timestep(timestep) + + # advance index counter by 1 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + self._index_counter[timestep_int] += 1 + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_next = self.sigmas[step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[step_index - 1] + sigma_next = self.sigmas[step_index] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma_hat if self.state_in_first_order else sigma_next + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma_hat if self.state_in_first_order else sigma_next + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + if self.state_in_first_order: + # 2. Convert to an ODE derivative for 1st order + derivative = (sample - pred_original_sample) / sigma_hat + # 3. delta timestep + dt = sigma_next - sigma_hat + + # store for 2nd order step + self.prev_derivative = derivative + self.dt = dt + self.sample = sample + else: + # 2. 2nd order / Heun's method + derivative = (sample - pred_original_sample) / sigma_next + derivative = (self.prev_derivative + derivative) / 2 + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.prev_derivative = None + self.dt = None + self.sample = None + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_ipndm.py b/diffusers/schedulers/scheduling_ipndm.py new file mode 100644 index 0000000000000000000000000000000000000000..80e521590782de6bc14e9b8c29642c7595fafc93 --- /dev/null +++ b/diffusers/schedulers/scheduling_ipndm.py @@ -0,0 +1,161 @@ +# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class IPNDMScheduler(SchedulerMixin, ConfigMixin): + """ + Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion + [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296) + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + """ + + order = 1 + + @register_to_config + def __init__( + self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None + ): + # set `betas`, `alphas`, `timesteps` + self.set_timesteps(num_train_timesteps) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + # running values + self.ets = [] + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1] + steps = torch.cat([steps, torch.tensor([0.0])]) + + if self.config.trained_betas is not None: + self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32) + else: + self.betas = torch.sin(steps * math.pi / 2) ** 2 + + self.alphas = (1.0 - self.betas**2) ** 0.5 + + timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1] + self.timesteps = timesteps.to(device) + + self.ets = [] + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep_index = (self.timesteps == timestep).nonzero().item() + prev_timestep_index = timestep_index + 1 + + ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index] + self.ets.append(ets) + + if len(self.ets) == 1: + ets = self.ets[-1] + elif len(self.ets) == 2: + ets = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets) + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets): + alpha = self.alphas[timestep_index] + sigma = self.betas[timestep_index] + + next_alpha = self.alphas[prev_timestep_index] + next_sigma = self.betas[prev_timestep_index] + + pred = (sample - sigma * ets) / max(alpha, 1e-8) + prev_sample = next_alpha * pred + ets * next_sigma + + return prev_sample + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf9379b9b90a53e3c8aad20a69e9ab7bffc691e --- /dev/null +++ b/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -0,0 +1,420 @@ +# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: + https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188 + + Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 + else: + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + pos = self._index_counter[timestep_int] + + return indices[pos].item() + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + ) -> torch.FloatTensor: + """ + Args: + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + step_index = self.index_for_timestep(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + else: + sigma = self.sigmas_interpol[step_index - 1] + + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) + + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + + # compute up and down sigmas + sigmas_next = sigmas.roll(-1) + sigmas_next[-1] = 0.0 + sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5 + sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5 + sigmas_down[-1] = 0.0 + + # compute interpolated sigmas + sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp() + sigmas_interpol[-2:] = 0.0 + + # set sigmas + self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) + self.sigmas_interpol = torch.cat( + [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] + ) + self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) + self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) + + if str(device).startswith("mps"): + # mps does not support float64 + timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + timesteps = torch.from_numpy(timesteps).to(device) + + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) + interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() + + self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) + + self.sample = None + + # for exp beta schedules, such as the one for `pipeline_shap_e.py` + # we need an index counter + self._index_counter = defaultdict(int) + + def sigma_to_t(self, sigma): + # get log sigma + log_sigma = sigma.log() + + # get distribution + dists = log_sigma - self.log_sigmas[:, None] + + # get sigmas range + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = self.log_sigmas[low_idx] + high = self.log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + return t + + @property + def state_in_first_order(self): + return self.sample is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep + (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + step_index = self.index_for_timestep(timestep) + + # advance index counter by 1 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + self._index_counter[timestep_int] += 1 + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_interpol = self.sigmas_interpol[step_index] + sigma_up = self.sigmas_up[step_index] + sigma_down = self.sigmas_down[step_index - 1] + else: + # 2nd order / KPDM2's method + sigma = self.sigmas[step_index - 1] + sigma_interpol = self.sigmas_interpol[step_index - 1] + sigma_up = self.sigmas_up[step_index - 1] + sigma_down = self.sigmas_down[step_index - 1] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + device = model_output.device + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if self.state_in_first_order: + # 2. Convert to an ODE derivative for 1st order + derivative = (sample - pred_original_sample) / sigma_hat + # 3. delta timestep + dt = sigma_interpol - sigma_hat + + # store for 2nd order step + self.sample = sample + self.dt = dt + prev_sample = sample + derivative * dt + else: + # DPM-Solver-2 + # 2. Convert to an ODE derivative for 2nd order + derivative = (sample - pred_original_sample) / sigma_interpol + # 3. delta timestep + dt = sigma_down - sigma_hat + + sample = self.sample + self.sample = None + + prev_sample = sample + derivative * dt + prev_sample = prev_sample + noise * sigma_up + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/diffusers/schedulers/scheduling_k_dpm_2_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a1b4e6640d1bc10ef6475bde39b5f39a87ec80 --- /dev/null +++ b/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -0,0 +1,401 @@ +# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: + https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188 + + Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 + else: + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + pos = self._index_counter[timestep_int] + + return indices[pos].item() + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + ) -> torch.FloatTensor: + """ + Args: + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + step_index = self.index_for_timestep(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + else: + sigma = self.sigmas_interpol[step_index] + + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) + + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + + # interpolate sigmas + sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp() + + self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) + self.sigmas_interpol = torch.cat( + [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] + ) + + if str(device).startswith("mps"): + # mps does not support float64 + timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + timesteps = torch.from_numpy(timesteps).to(device) + + # interpolate timesteps + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) + interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() + + self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) + + self.sample = None + + # for exp beta schedules, such as the one for `pipeline_shap_e.py` + # we need an index counter + self._index_counter = defaultdict(int) + + def sigma_to_t(self, sigma): + # get log sigma + log_sigma = sigma.log() + + # get distribution + dists = log_sigma - self.log_sigmas[:, None] + + # get sigmas range + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = self.log_sigmas[low_idx] + high = self.log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + return t + + @property + def state_in_first_order(self): + return self.sample is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep + (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + step_index = self.index_for_timestep(timestep) + + # advance index counter by 1 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + self._index_counter[timestep_int] += 1 + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_interpol = self.sigmas_interpol[step_index + 1] + sigma_next = self.sigmas[step_index + 1] + else: + # 2nd order / KDPM2's method + sigma = self.sigmas[step_index - 1] + sigma_interpol = self.sigmas_interpol[step_index] + sigma_next = self.sigmas[step_index] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if self.state_in_first_order: + # 2. Convert to an ODE derivative for 1st order + derivative = (sample - pred_original_sample) / sigma_hat + # 3. delta timestep + dt = sigma_interpol - sigma_hat + + # store for 2nd order step + self.sample = sample + else: + # DPM-Solver-2 + # 2. Convert to an ODE derivative for 2nd order + derivative = (sample - pred_original_sample) / sigma_interpol + + # 3. delta timestep + dt = sigma_next - sigma_hat + + sample = self.sample + self.sample = None + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_karras_ve.py b/diffusers/schedulers/scheduling_karras_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..87f6514a4e93e4a75bd6228ed852306b8c005c3d --- /dev/null +++ b/diffusers/schedulers/scheduling_karras_ve.py @@ -0,0 +1,232 @@ +# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import SchedulerMixin + + +@dataclass +class KarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Derivative of predicted original image sample (x_0). + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + derivative: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class KarrasVeScheduler(SchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + + """ + + order = 2 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.02, + sigma_max: float = 100, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + # setable values + self.num_inference_steps: int = None + self.timesteps: np.IntTensor = None + self.schedule: torch.FloatTensor = None # sigma(t_i) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + """ + self.num_inference_steps = num_inference_steps + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) + schedule = [ + ( + self.config.sigma_max**2 + * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) + ) + for i in self.timesteps + ] + self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device) + + def add_noise_to_input( + self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None + ) -> Tuple[torch.FloatTensor, float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: + """ + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + model_output: torch.FloatTensor, + sigma_hat: float, + sigma_prev: float, + sample_hat: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor`): TODO + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class + + KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). + Returns: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + if not return_dict: + return (sample_prev, derivative) + + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) + + def step_correct( + self, + model_output: torch.FloatTensor, + sigma_hat: float, + sigma_prev: float, + sample_hat: torch.FloatTensor, + sample_prev: torch.FloatTensor, + derivative: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. TODO complete description + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor`): TODO + sample_prev (`torch.FloatTensor`): TODO + derivative (`torch.FloatTensor`): TODO + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO + + """ + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + + if not return_dict: + return (sample_prev, derivative) + + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) + + def add_noise(self, original_samples, noise, timesteps): + raise NotImplementedError() diff --git a/diffusers/schedulers/scheduling_karras_ve_flax.py b/diffusers/schedulers/scheduling_karras_ve_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..45c0dbddf7efd22df21cc9859e68d62b54aa8609 --- /dev/null +++ b/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -0,0 +1,237 @@ +# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils_flax import FlaxSchedulerMixin + + +@flax.struct.dataclass +class KarrasVeSchedulerState: + # setable values + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + schedule: Optional[jnp.ndarray] = None # sigma(t_i) + + @classmethod + def create(cls): + return cls() + + +@dataclass +class FlaxKarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Derivative of predicted original image sample (x_0). + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + """ + + prev_sample: jnp.ndarray + derivative: jnp.ndarray + state: KarrasVeSchedulerState + + +class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + """ + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + sigma_min: float = 0.02, + sigma_max: float = 100, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + ): + pass + + def create_state(self): + return KarrasVeSchedulerState.create() + + def set_timesteps( + self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> KarrasVeSchedulerState: + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`KarrasVeSchedulerState`): + the `FlaxKarrasVeScheduler` state data class. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + """ + timesteps = jnp.arange(0, num_inference_steps)[::-1].copy() + schedule = [ + ( + self.config.sigma_max**2 + * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) + ) + for i in timesteps + ] + + return state.replace( + num_inference_steps=num_inference_steps, + schedule=jnp.array(schedule, dtype=jnp.float32), + timesteps=timesteps, + ) + + def add_noise_to_input( + self, + state: KarrasVeSchedulerState, + sample: jnp.ndarray, + sigma: float, + key: random.KeyArray, + ) -> Tuple[jnp.ndarray, float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: + """ + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + key = random.split(key, num=1) + eps = self.config.s_noise * random.normal(key=key, shape=sample.shape) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + state: KarrasVeSchedulerState, + model_output: jnp.ndarray, + sigma_hat: float, + sigma_prev: float, + sample_hat: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxKarrasVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class + + Returns: + [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion + chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + if not return_dict: + return (sample_prev, derivative, state) + + return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) + + def step_correct( + self, + state: KarrasVeSchedulerState, + model_output: jnp.ndarray, + sigma_hat: float, + sigma_prev: float, + sample_hat: jnp.ndarray, + sample_prev: jnp.ndarray, + derivative: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxKarrasVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. TODO complete description + + Args: + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO + derivative (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO + + """ + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + + if not return_dict: + return (sample_prev, derivative, state) + + return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) + + def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps): + raise NotImplementedError() diff --git a/diffusers/schedulers/scheduling_lms_discrete.py b/diffusers/schedulers/scheduling_lms_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..d58d4ce45bd17645b86905c1ae36ce937015fc29 --- /dev/null +++ b/diffusers/schedulers/scheduling_lms_discrete.py @@ -0,0 +1,413 @@ +# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete +class LMSDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # setable values + self.num_inference_steps = None + self.use_karras_sigmas = use_karras_sigmas + self.set_timesteps(num_train_timesteps, None) + self.derivatives = [] + self.is_scale_input_called = False + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def get_lms_coefficient(self, order, t, current_order): + """ + Compute a linear multistep coefficient. + + Args: + order (TODO): + t (TODO): + current_order (TODO): + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + self.derivatives = [] + + # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + order: int = 4, + return_dict: bool = True, + ) -> Union[LMSDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. + return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if not self.is_scale_input_called: + warnings.warn( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + if len(self.derivatives) > order: + self.derivatives.pop(0) + + # 3. Compute linear multistep coefficients + order = min(step_index + 1, order) + lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)] + + # 4. Compute previous sample based on the derivatives path + prev_sample = sample + sum( + coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives)) + ) + + if not return_dict: + return (prev_sample,) + + return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_lms_discrete_flax.py b/diffusers/schedulers/scheduling_lms_discrete_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..f96e602afe121a09876b0ff7db1d3192e441e32a --- /dev/null +++ b/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -0,0 +1,283 @@ +# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + + +@flax.struct.dataclass +class LMSDiscreteSchedulerState: + common: CommonSchedulerState + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + sigmas: jnp.ndarray + num_inference_steps: Optional[int] = None + + # running values + derivatives: Optional[jnp.ndarray] = None + + @classmethod + def create( + cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + ): + return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + + +@dataclass +class FlaxLMSSchedulerOutput(FlaxSchedulerOutput): + state: LMSDiscreteSchedulerState + + +class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + init_noise_sigma = sigmas.max() + + return LMSDiscreteSchedulerState.create( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) + + def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. + + Args: + state (`LMSDiscreteSchedulerState`): + the `FlaxLMSDiscreteScheduler` state data class instance. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + timestep (`int`): + current discrete timestep in the diffusion chain. + + Returns: + `jnp.ndarray`: scaled input sample + """ + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + sigma = state.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order): + """ + Compute a linear multistep coefficient. + + Args: + order (TODO): + t (TODO): + current_order (TODO): + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps( + self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + ) -> LMSDiscreteSchedulerState: + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`LMSDiscreteSchedulerState`): + the `FlaxLMSDiscreteScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + + low_idx = jnp.floor(timesteps).astype(jnp.int32) + high_idx = jnp.ceil(timesteps).astype(jnp.int32) + + frac = jnp.mod(timesteps, 1.0) + + sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5 + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) + + timesteps = timesteps.astype(jnp.int32) + + # initial running values + derivatives = jnp.zeros((0,) + shape, dtype=self.dtype) + + return state.replace( + timesteps=timesteps, + sigmas=sigmas, + num_inference_steps=num_inference_steps, + derivatives=derivatives, + ) + + def step( + self, + state: LMSDiscreteSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + order: int = 4, + return_dict: bool = True, + ) -> Union[FlaxLMSSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. + return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class + + Returns: + [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + sigma = state.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + state = state.replace(derivatives=jnp.append(state.derivatives, derivative)) + if len(state.derivatives) > order: + state = state.replace(derivatives=jnp.delete(state.derivatives, 0)) + + # 3. Compute linear multistep coefficients + order = min(timestep + 1, order) + lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)] + + # 4. Compute previous sample based on the derivatives path + prev_sample = sample + sum( + coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives)) + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + state: LMSDiscreteSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sigma = state.sigmas[timesteps].flatten() + sigma = broadcast_to_shape_from_left(sigma, noise.shape) + + noisy_samples = original_samples + noise * sigma + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_pndm.py b/diffusers/schedulers/scheduling_pndm.py new file mode 100644 index 0000000000000000000000000000000000000000..794eb3674c1bb5533b938b00b08d48cd5192c317 --- /dev/null +++ b/diffusers/schedulers/scheduling_pndm.py @@ -0,0 +1,462 @@ +# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class PNDMScheduler(SchedulerMixin, ConfigMixin): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + skip_prk_steps (`bool`): + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) + or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + prediction_type: str = "epsilon", + timestep_spacing: str = "leading", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + # running values + self.cur_model_output = 0 + self.counter = 0 + self.cur_sample = None + self.ets = [] + + # setable values + self.num_inference_steps = None + self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.prk_timesteps = None + self.plms_timesteps = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + self.num_inference_steps = num_inference_steps + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + self._timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype( + np.int64 + ) + self._timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[ + ::-1 + ].copy() + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.ets = [] + self.counter = 0 + self.cur_model_output = 0 + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) + + def step_prk( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = timestep - diff_to_prev + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output += 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = 0 + + # cur_sample should not be `None` + cur_sample = self.cur_sample if self.cur_sample is not None else sample + + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def step_plms( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.config.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " + "for more information." + ) + + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + if self.counter != 1: + self.ets = self.ets[-3:] + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = None + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.config.prediction_type == "v_prediction": + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + elif self.config.prediction_type != "epsilon": + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" + ) + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_pndm_flax.py b/diffusers/schedulers/scheduling_pndm_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..c654f2de8dd3e4f96403cce4b9db8f8b7b69861f --- /dev/null +++ b/diffusers/schedulers/scheduling_pndm_flax.py @@ -0,0 +1,511 @@ +# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, +) + + +@flax.struct.dataclass +class PNDMSchedulerState: + common: CommonSchedulerState + final_alpha_cumprod: jnp.ndarray + + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + prk_timesteps: Optional[jnp.ndarray] = None + plms_timesteps: Optional[jnp.ndarray] = None + + # running values + cur_model_output: Optional[jnp.ndarray] = None + counter: Optional[jnp.int32] = None + cur_sample: Optional[jnp.ndarray] = None + ets: Optional[jnp.ndarray] = None + + @classmethod + def create( + cls, + common: CommonSchedulerState, + final_alpha_cumprod: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + +@dataclass +class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput): + state: PNDMSchedulerState + + +class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + skip_prk_steps (`bool`): + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. + """ + + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] + + dtype: jnp.dtype + pndm_order: int + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 0, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + final_alpha_cumprod = ( + jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] + ) + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return PNDMSchedulerState.create( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) + + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`PNDMSchedulerState`): + the `FlaxPNDMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + the shape of the samples to be generated. + """ + + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # rounding to avoid issues when num_inference_step is power of 3 + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + + prk_timesteps = jnp.array([], dtype=jnp.int32) + plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1] + + else: + prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( + jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), + self.pndm_order, + ) + + prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1] + plms_timesteps = _timesteps[:-3][::-1] + + timesteps = jnp.concatenate([prk_timesteps, plms_timesteps]) + + # initial running values + + cur_model_output = jnp.zeros(shape, dtype=self.dtype) + counter = jnp.int32(0) + cur_sample = jnp.zeros(shape, dtype=self.dtype) + ets = jnp.zeros((4,) + shape, dtype=self.dtype) + + return state.replace( + timesteps=timesteps, + num_inference_steps=num_inference_steps, + prk_timesteps=prk_timesteps, + plms_timesteps=plms_timesteps, + cur_model_output=cur_model_output, + counter=counter, + cur_sample=cur_sample, + ets=ets, + ) + + def scale_model_input( + self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def step( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class + + Returns: + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.config.skip_prk_steps: + prev_sample, state = self.step_plms(state, model_output, timestep, sample) + else: + prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample) + plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample) + + cond = state.counter < len(state.prk_timesteps) + + prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample) + + state = state.replace( + cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output), + ets=jax.lax.select(cond, prk_state.ets, plms_state.ets), + cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample), + counter=jax.lax.select(cond, prk_state.counter, plms_state.counter), + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state) + + def step_prk( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class + + Returns: + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = jnp.where( + state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 + ) + prev_timestep = timestep - diff_to_prev + timestep = state.prk_timesteps[state.counter // 4 * 4] + + model_output = jax.lax.select( + (state.counter % 4) != 3, + model_output, # remainder 0, 1, 2 + state.cur_model_output + 1 / 6 * model_output, # remainder 3 + ) + + state = state.replace( + cur_model_output=jax.lax.select_n( + state.counter % 4, + state.cur_model_output + 1 / 6 * model_output, # remainder 0 + state.cur_model_output + 1 / 3 * model_output, # remainder 1 + state.cur_model_output + 1 / 3 * model_output, # remainder 2 + jnp.zeros_like(state.cur_model_output), # remainder 3 + ), + ets=jax.lax.select( + (state.counter % 4) == 0, + state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0 + state.ets, # remainder 1, 2, 3 + ), + cur_sample=jax.lax.select( + (state.counter % 4) == 0, + sample, # remainder 0 + state.cur_sample, # remainder 1, 2, 3 + ), + ) + + cur_sample = state.cur_sample + prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output) + state = state.replace(counter=state.counter + 1) + + return (prev_sample, state) + + def step_plms( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class + + Returns: + [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # NOTE: There is no way to check in the jitted runtime if the prk mode was ran before + + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) + + # Reference: + # if state.counter != 1: + # state.ets.append(model_output) + # else: + # prev_timestep = timestep + # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps + + prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) + timestep = jnp.where( + state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep + ) + + # Reference: + # if len(state.ets) == 1 and state.counter == 0: + # model_output = model_output + # state.cur_sample = sample + # elif len(state.ets) == 1 and state.counter == 1: + # model_output = (model_output + state.ets[-1]) / 2 + # sample = state.cur_sample + # state.cur_sample = None + # elif len(state.ets) == 2: + # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 + # elif len(state.ets) == 3: + # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 + # else: + # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]) + + state = state.replace( + ets=jax.lax.select( + state.counter != 1, + state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter != 1 + state.ets, # counter 1 + ), + cur_sample=jax.lax.select( + state.counter != 1, + sample, # counter != 1 + state.cur_sample, # counter 1 + ), + ) + + state = state.replace( + cur_model_output=jax.lax.select_n( + jnp.clip(state.counter, 0, 4), + model_output, # counter 0 + (model_output + state.ets[-1]) / 2, # counter 1 + (3 * state.ets[-1] - state.ets[-2]) / 2, # counter 2 + (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12, # counter 3 + (1 / 24) + * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]), # counter >= 4 + ), + ) + + sample = state.cur_sample + model_output = state.cur_model_output + prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output) + state = state.replace(counter=state.counter + 1) + + return (prev_sample, state) + + def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = state.common.alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where( + prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.config.prediction_type == "v_prediction": + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + elif self.config.prediction_type != "epsilon": + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" + ) + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample + + def add_noise( + self, + state: PNDMSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_repaint.py b/diffusers/schedulers/scheduling_repaint.py new file mode 100644 index 0000000000000000000000000000000000000000..41e7450d2df68c40c3b4f49669513832e443c5e3 --- /dev/null +++ b/diffusers/schedulers/scheduling_repaint.py @@ -0,0 +1,344 @@ +# Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import SchedulerMixin + + +@dataclass +class RePaintSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from + the current timestep. `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: torch.FloatTensor + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class RePaintScheduler(SchedulerMixin, ConfigMixin): + """ + RePaint is a schedule for DDPM inpainting inside a given mask. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. + eta (`float`): + The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and + 1.0 is DDPM scheduler respectively. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + eta: float = 0.0, + trained_betas: Optional[np.ndarray] = None, + clip_sample: bool = True, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + self.final_alpha_cumprod = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.eta = eta + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps( + self, + num_inference_steps: int, + jump_length: int = 10, + jump_n_sample: int = 10, + device: Union[str, torch.device] = None, + ): + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + self.num_inference_steps = num_inference_steps + + timesteps = [] + + jumps = {} + for j in range(0, num_inference_steps - jump_length, jump_length): + jumps[j] = jump_n_sample - 1 + + t = num_inference_steps + while t >= 1: + t = t - 1 + timesteps.append(t) + + if jumps.get(t, 0) > 0: + jumps[t] = jumps[t] - 1 + for _ in range(jump_length): + t = t + 1 + timesteps.append(t) + + timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t): + prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from + # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get + # previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add + # variance to pred_sample + # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf + # without eta. + # variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + original_image: torch.FloatTensor, + mask: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[RePaintSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned + diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + original_image (`torch.FloatTensor`): + the original image to inpaint on. + mask (`torch.FloatTensor`): + the mask where 0.0 values define which part of the original image to inpaint (change). + generator (`torch.Generator`, *optional*): random number generator. + return_dict (`bool`): option for returning tuple rather than + DDPMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + t = timestep + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we + # substitute formula (7) in the algorithm coming from DDPM paper + # (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper. + # DDIM schedule gives the same results as DDPM with eta = 1.0 + # Noise is being reused in 7. and 8., but no impact on quality has + # been observed. + + # 5. Add noise + device = model_output.device + noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype) + std_dev_t = self.eta * self._get_variance(timestep) ** 0.5 + + variance = 0 + if t > 0 and self.eta > 0: + variance = std_dev_t * noise + + # 6. compute "direction pointing to x_t" of formula (12) + # from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output + + # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance + + # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf + prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise + + # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf + pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part + + if not return_dict: + return ( + pred_prev_sample, + pred_original_sample, + ) + + return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + def undo_step(self, sample, timestep, generator=None): + n = self.config.num_train_timesteps // self.num_inference_steps + + for i in range(n): + beta = self.betas[timestep + i] + if sample.device.type == "mps": + # randn does not work reproducibly on mps + noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator) + noise = noise.to(sample.device) + else: + noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) + + # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf + sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise + + return sample + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.") + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_sde_ve.py b/diffusers/schedulers/scheduling_sde_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..339edfbb02eb6ac0f79b3969004418bb29e212b5 --- /dev/null +++ b/diffusers/schedulers/scheduling_sde_ve.py @@ -0,0 +1,288 @@ +# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +@dataclass +class SdeVeOutput(BaseOutput): + """ + Output class for the ScoreSdeVeScheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. + """ + + prev_sample: torch.FloatTensor + prev_sample_mean: torch.FloatTensor + + +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): + """ + The variance exploding stochastic differential equation (SDE) scheduler. + + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + snr (`float`): + coefficient weighting the step from the model_output sample (from the network) to the random noise. + sigma_min (`float`): + initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the + distribution of the data. + sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. + sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to + epsilon. + correct_steps (`int`): number of correction steps performed on a produced sample. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + + # setable values + self.timesteps = None + + self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps( + self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None + ): + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sampling_eps (`float`, optional): + final timestep value (overrides value given at Scheduler instantiation). + + """ + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + + self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device) + + def set_sigmas( + self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None + ): + """ + Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. + + The sigmas control the weight of the `drift` and `diffusion` components of sample update. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sigma_min (`float`, optional): + initial noise scale value (overrides value given at Scheduler instantiation). + sigma_max (`float`, optional): + final noise scale value (overrides value given at Scheduler instantiation). + sampling_eps (`float`, optional): + final timestep value (overrides value given at Scheduler instantiation). + + """ + sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min + sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + if self.timesteps is None: + self.set_timesteps(num_inference_steps, sampling_eps) + + self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps) + self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps)) + self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) + + def get_adjacent_sigma(self, timesteps, t): + return torch.where( + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), + ) + + def step_pred( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SdeVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep = timestep * torch.ones( + sample.shape[0], device=sample.device + ) # torch.repeat_interleave(timestep, sample.shape[0]) + timesteps = (timestep * (len(self.timesteps) - 1)).long() + + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.discrete_sigmas.device) + + sigma = self.discrete_sigmas[timesteps].to(sample.device) + adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) + drift = torch.zeros_like(sample) + diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 + + # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) + # also equation 47 shows the analog from SDE models to ancestral sampling methods + diffusion = diffusion.flatten() + while len(diffusion.shape) < len(sample.shape): + diffusion = diffusion.unsqueeze(-1) + drift = drift - diffusion**2 * model_output + + # equation 6: sample noise for the diffusion term of + noise = randn_tensor( + sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype + ) + prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep + # TODO is the variable diffusion the correct scaling term for the noise? + prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g + + if not return_dict: + return (prev_sample, prev_sample_mean) + + return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) + + def step_correct( + self, + model_output: torch.FloatTensor, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. This is often run repeatedly + after making the prediction for the previous timestep. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" + # sample noise for correction + noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device) + + # compute step size from the model_output, the noise, and the snr + grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) + # self.repeat_scalar(step_size, sample.shape[0]) + + # compute corrected sample: model_output term and noise term + step_size = step_size.flatten() + while len(step_size.shape) < len(sample.shape): + step_size = step_size.unsqueeze(-1) + prev_sample_mean = sample + step_size * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + timesteps = timesteps.to(original_samples.device) + sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps] + noise = ( + noise * sigmas[:, None, None, None] + if noise is not None + else torch.randn_like(original_samples) * sigmas[:, None, None, None] + ) + noisy_samples = noise + original_samples + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_sde_ve_flax.py b/diffusers/schedulers/scheduling_sde_ve_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..b6240559fc88fa45e4612dc3005ba66e10d3269d --- /dev/null +++ b/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -0,0 +1,279 @@ +# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left + + +@flax.struct.dataclass +class ScoreSdeVeSchedulerState: + # setable values + timesteps: Optional[jnp.ndarray] = None + discrete_sigmas: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + + @classmethod + def create(cls): + return cls() + + +@dataclass +class FlaxSdeVeOutput(FlaxSchedulerOutput): + """ + Output class for the ScoreSdeVeScheduler's step function output. + + Args: + state (`ScoreSdeVeSchedulerState`): + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. + """ + + state: ScoreSdeVeSchedulerState + prev_sample: jnp.ndarray + prev_sample_mean: Optional[jnp.ndarray] = None + + +class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + The variance exploding stochastic differential equation (SDE) scheduler. + + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + snr (`float`): + coefficient weighting the step from the model_output sample (from the network) to the random noise. + sigma_min (`float`): + initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the + distribution of the data. + sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. + sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to + epsilon. + correct_steps (`int`): number of correction steps performed on a produced sample. + """ + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + ): + pass + + def create_state(self): + state = ScoreSdeVeSchedulerState.create() + return self.set_sigmas( + state, + self.config.num_train_timesteps, + self.config.sigma_min, + self.config.sigma_max, + self.config.sampling_eps, + ) + + def set_timesteps( + self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None + ) -> ScoreSdeVeSchedulerState: + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sampling_eps (`float`, optional): + final timestep value (overrides value given at Scheduler instantiation). + + """ + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + + timesteps = jnp.linspace(1, sampling_eps, num_inference_steps) + return state.replace(timesteps=timesteps) + + def set_sigmas( + self, + state: ScoreSdeVeSchedulerState, + num_inference_steps: int, + sigma_min: float = None, + sigma_max: float = None, + sampling_eps: float = None, + ) -> ScoreSdeVeSchedulerState: + """ + Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. + + The sigmas control the weight of the `drift` and `diffusion` components of sample update. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sigma_min (`float`, optional): + initial noise scale value (overrides value given at Scheduler instantiation). + sigma_max (`float`, optional): + final noise scale value (overrides value given at Scheduler instantiation). + sampling_eps (`float`, optional): + final timestep value (overrides value given at Scheduler instantiation). + """ + sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min + sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + if state.timesteps is None: + state = self.set_timesteps(state, num_inference_steps, sampling_eps) + + discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps)) + sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps]) + + return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas) + + def get_adjacent_sigma(self, state, timesteps, t): + return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1]) + + def step_pred( + self, + state: ScoreSdeVeSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: random.KeyArray, + return_dict: bool = True, + ) -> Union[FlaxSdeVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class + + Returns: + [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if state.timesteps is None: + raise ValueError( + "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep = timestep * jnp.ones( + sample.shape[0], + ) + timesteps = (timestep * (len(state.timesteps) - 1)).long() + + sigma = state.discrete_sigmas[timesteps] + adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep) + drift = jnp.zeros_like(sample) + diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 + + # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) + # also equation 47 shows the analog from SDE models to ancestral sampling methods + diffusion = diffusion.flatten() + diffusion = broadcast_to_shape_from_left(diffusion, sample.shape) + drift = drift - diffusion**2 * model_output + + # equation 6: sample noise for the diffusion term of + key = random.split(key, num=1) + noise = random.normal(key=key, shape=sample.shape) + prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep + # TODO is the variable diffusion the correct scaling term for the noise? + prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g + + if not return_dict: + return (prev_sample, prev_sample_mean, state) + + return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state) + + def step_correct( + self, + state: ScoreSdeVeSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + key: random.KeyArray, + return_dict: bool = True, + ) -> Union[FlaxSdeVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. This is often run repeatedly + after making the prediction for the previous timestep. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class + + Returns: + [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if state.timesteps is None: + raise ValueError( + "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" + # sample noise for correction + key = random.split(key, num=1) + noise = random.normal(key=key, shape=sample.shape) + + # compute step size from the model_output, the noise, and the snr + grad_norm = jnp.linalg.norm(model_output) + noise_norm = jnp.linalg.norm(noise) + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * jnp.ones(sample.shape[0]) + + # compute corrected sample: model_output term and noise term + step_size = step_size.flatten() + step_size = broadcast_to_shape_from_left(step_size, sample.shape) + prev_sample_mean = sample + step_size * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise + + if not return_dict: + return (prev_sample, state) + + return FlaxSdeVeOutput(prev_sample=prev_sample, state=state) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_sde_vp.py b/diffusers/schedulers/scheduling_sde_vp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e2ead90edb57cd1eb1d270695e222d404064180 --- /dev/null +++ b/diffusers/schedulers/scheduling_sde_vp.py @@ -0,0 +1,90 @@ +# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +import math +from typing import Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + """ + The variance preserving stochastic differential equation (SDE) scheduler. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + UNDER CONSTRUCTION + + """ + + order = 1 + + @register_to_config + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device) + + def step_pred(self, score, x, t, generator=None): + if self.timesteps is None: + raise ValueError( + "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # TODO(Patrick) better comments + non-PyTorch + # postprocess model score + log_mean_coeff = ( + -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min + ) + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) + std = std.flatten() + while len(std.shape) < len(score.shape): + std = std.unsqueeze(-1) + score = -score / std + + # compute + dt = -1.0 / len(self.timesteps) + + beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) + beta_t = beta_t.flatten() + while len(beta_t.shape) < len(x.shape): + beta_t = beta_t.unsqueeze(-1) + drift = -0.5 * beta_t * x + + diffusion = torch.sqrt(beta_t) + drift = drift - diffusion**2 * score + x_mean = x + drift * dt + + # add noise + noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype) + x = x_mean + diffusion * math.sqrt(-dt) * noise + + return x, x_mean + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_unclip.py b/diffusers/schedulers/scheduling_unclip.py new file mode 100644 index 0000000000000000000000000000000000000000..fd23e48bad00d16a1086f31b6584ff9df03129fb --- /dev/null +++ b/diffusers/schedulers/scheduling_unclip.py @@ -0,0 +1,348 @@ +# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, randn_tensor +from .scheduling_utils import SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP +class UnCLIPSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class UnCLIPScheduler(SchedulerMixin, ConfigMixin): + """ + NOTE: do not use this scheduler. The DDPM scheduler has been updated to support the changes made here. This + scheduler will be removed and replaced with DDPM. + + This is a modified DDPM Scheduler specifically for the karlo unCLIP model. + + This scheduler has some minor variations in how it calculates the learned range variance and dynamically + re-calculates betas based off the timesteps it is skipping. + + The scheduler also uses a slightly different step ratio when computing timesteps to use for inference. + + See [`~DDPMScheduler`] for more information on DDPM scheduling + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small_log` + or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical + stability. + clip_sample_range (`float`, default `1.0`): + The range to clip the sample between. See `clip_sample`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) + or `sample` (directly predicting the noisy sample`) + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + variance_type: str = "fixed_small_log", + clip_sample: bool = True, + clip_sample_range: Optional[float] = 1.0, + prediction_type: str = "epsilon", + beta_schedule: str = "squaredcos_cap_v2", + ): + if beta_schedule != "squaredcos_cap_v2": + raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'") + + self.betas = betas_for_alpha_bar(num_train_timesteps) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.variance_type = variance_type + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The + different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy + of the results. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + step_ratio = (self.config.num_train_timesteps - 1) / (self.num_inference_steps - 1) + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None): + if prev_timestep is None: + prev_timestep = t - 1 + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if prev_timestep == t - 1: + beta = self.betas[t] + else: + beta = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = beta_prod_t_prev / beta_prod_t * beta + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small_log": + variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.exp(0.5 * variance) + elif variance_type == "learned_range": + # NOTE difference with DDPM scheduler + min_log = variance.log() + max_log = beta.log() + + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + prev_timestep: Optional[int] = None, + generator=None, + return_dict: bool = True, + ) -> Union[UnCLIPSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at. + Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range": + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + if prev_timestep is None: + prev_timestep = t - 1 + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if prev_timestep == t - 1: + beta = self.betas[t] + alpha = self.alphas[t] + else: + beta = 1 - alpha_prod_t / alpha_prod_t_prev + alpha = 1 - beta + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`" + " for the UnCLIPScheduler." + ) + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t + current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + variance_noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device + ) + + variance = self._get_variance( + t, + predicted_variance=predicted_variance, + prev_timestep=prev_timestep, + ) + + if self.variance_type == "fixed_small_log": + variance = variance + elif self.variance_type == "learned_range": + variance = (0.5 * variance).exp() + else: + raise ValueError( + f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`" + " for the UnCLIPScheduler." + ) + + variance = variance * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples diff --git a/diffusers/schedulers/scheduling_unipc_multistep.py b/diffusers/schedulers/scheduling_unipc_multistep.py new file mode 100644 index 0000000000000000000000000000000000000000..3caa01a58562f5f12d46354ef6112a64875da79d --- /dev/null +++ b/diffusers/schedulers/scheduling_unipc_multistep.py @@ -0,0 +1,681 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a + corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is + by desinged model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can + also be applied to both noise prediction model and data prediction model. The corrector UniC can be also applied + after any off-the-shelf solvers to increase the order of accuracy. + + For more details, see the original paper: https://arxiv.org/abs/2302.04867 + + Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We recommend + to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note + that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of + accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided sampling, + and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the + dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models + (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, default `True`): + whether to use the updating algrithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details + solver_type (`str`, default `bh2`): + the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + decide which step to disable the corrector. For large guidance scale, the misalignment between the + `epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated + by disable the corrector at the first few steps (e.g., disable_corrector=[0]) + solver_p (`SchedulerMixin`, default `None`): + can be any other scheduler. If specified, the algorithm will become solver_p + UniC. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + r""" + Convert the model output to the corresponding type that the algorithm PC needs. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + if self.predict_x0: + if self.config.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.FloatTensor`): + direct outputs from learned diffusion model at the current timestep. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + order (`int`): the order of UniP at this step, also the p in UniPC-p. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = self.timestep_list[-1], prev_timestep + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + this_sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.FloatTensor`): the model outputs at `x_t` + this_timestep (`int`): the current timestep `t` + last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}` + this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}` + order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy + should be order + 1 + + Returns: + `torch.FloatTensor`: the corrected sample tensor at the current timestep. + """ + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = timestep_list[-1], this_timestep + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep UniPC. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = ( + step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + # now prepare to run the predictor + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - step_index) + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + prev_timestep=prev_timestep, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/diffusers/schedulers/scheduling_utils.py b/diffusers/schedulers/scheduling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f95beb022ac042b6e1ef588a72365b2623338de --- /dev/null +++ b/diffusers/schedulers/scheduling_utils.py @@ -0,0 +1,177 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional, Union + +import torch + +from ..utils import BaseOutput + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +# NOTE: We make this type an enum because it simplifies usage in docs and prevents +# circular imports when used for `_compatibles` within the schedulers module. +# When it's used as a type in pipelines, it really is a Union because the actual +# scheduler instance is passed in. +class KarrasDiffusionSchedulers(Enum): + DDIMScheduler = 1 + DDPMScheduler = 2 + PNDMScheduler = 3 + LMSDiscreteScheduler = 4 + EulerDiscreteScheduler = 5 + HeunDiscreteScheduler = 6 + EulerAncestralDiscreteScheduler = 7 + DPMSolverMultistepScheduler = 8 + DPMSolverSinglestepScheduler = 9 + KDPM2DiscreteScheduler = 10 + KDPM2AncestralDiscreteScheduler = 11 + DEISMultistepScheduler = 12 + UniPCMultistepScheduler = 13 + DPMSolverSDEScheduler = 14 + + +@dataclass +class SchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class SchedulerMixin: + """ + Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). + """ + + config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Dict[str, Any] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing the schedluer configurations saved using + [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs, commit_hash = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + return_commit_hash=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~SchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes diff --git a/diffusers/schedulers/scheduling_utils_flax.py b/diffusers/schedulers/scheduling_utils_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..19ce5b8360b9be5bb4b4ec46fbeac0715d6b5869 --- /dev/null +++ b/diffusers/schedulers/scheduling_utils_flax.py @@ -0,0 +1,284 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import math +import os +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..utils import BaseOutput + + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +# NOTE: We make this type an enum because it simplifies usage in docs and prevents +# circular imports when used for `_compatibles` within the schedulers module. +# When it's used as a type in pipelines, it really is a Union because the actual +# scheduler instance is passed in. +class FlaxKarrasDiffusionSchedulers(Enum): + FlaxDDIMScheduler = 1 + FlaxDDPMScheduler = 2 + FlaxPNDMScheduler = 3 + FlaxLMSDiscreteScheduler = 4 + FlaxDPMSolverMultistepScheduler = 5 + + +@dataclass +class FlaxSchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: jnp.ndarray + + +class FlaxSchedulerMixin: + """ + Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). + """ + + config_name = SCHEDULER_CONFIG_NAME + ignore_for_config = ["dtype"] + _compatibles = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Dict[str, Any] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`], + e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs) + + if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False): + state = scheduler.create_state() + + if return_unused_kwargs: + return scheduler, state, unused_kwargs + + return scheduler, state + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~FlaxSchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes + + +def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: + assert len(shape) >= x.ndim + return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape) + + +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=dtype) + + +@flax.struct.dataclass +class CommonSchedulerState: + alphas: jnp.ndarray + betas: jnp.ndarray + alphas_cumprod: jnp.ndarray + + @classmethod + def create(cls, scheduler): + config = scheduler.config + + if config.trained_betas is not None: + betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype) + elif config.beta_schedule == "linear": + betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype) + elif config.beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + betas = ( + jnp.linspace( + config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype + ) + ** 2 + ) + elif config.beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype) + else: + raise NotImplementedError( + f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}" + ) + + alphas = 1.0 - betas + + alphas_cumprod = jnp.cumprod(alphas, axis=0) + + return cls( + alphas=alphas, + betas=betas, + alphas_cumprod=alphas_cumprod, + ) + + +def get_sqrt_alpha_prod( + state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray +): + alphas_cumprod = state.alphas_cumprod + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) + + return sqrt_alpha_prod, sqrt_one_minus_alpha_prod + + +def add_noise_common( + state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray +): + sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps) + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + +def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray): + sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps) + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/diffusers/schedulers/scheduling_vq_diffusion.py b/diffusers/schedulers/scheduling_vq_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b92722e4d462ca675bbf11230c1c39810de48b6e --- /dev/null +++ b/diffusers/schedulers/scheduling_vq_diffusion.py @@ -0,0 +1,496 @@ +# Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class VQDiffusionSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.LongTensor + + +def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor: + """ + Convert batch of vector of class indices into batch of log onehot vectors + + Args: + x (`torch.LongTensor` of shape `(batch size, vector length)`): + Batch of class indices + + num_classes (`int`): + number of classes to be used for the onehot vectors + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes, vector length)`: + Log onehot vectors + """ + x_onehot = F.one_hot(x, num_classes) + x_onehot = x_onehot.permute(0, 2, 1) + log_x = torch.log(x_onehot.float().clamp(min=1e-30)) + return log_x + + +def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor: + """ + Apply gumbel noise to `logits` + """ + uniform = torch.rand(logits.shape, device=logits.device, generator=generator) + gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) + noised = gumbel_noise + logits + return noised + + +def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009): + """ + Cumulative and non-cumulative alpha schedules. + + See section 4.1. + """ + att = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start) + + alpha_cum_start + ) + att = np.concatenate(([1], att)) + at = att[1:] / att[:-1] + att = np.concatenate((att[1:], [1])) + return at, att + + +def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999): + """ + Cumulative and non-cumulative gamma schedules. + + See section 4.1. + """ + ctt = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start) + + gamma_cum_start + ) + ctt = np.concatenate(([0], ctt)) + one_minus_ctt = 1 - ctt + one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1] + ct = 1 - one_minus_ct + ctt = np.concatenate((ctt[1:], [0])) + return ct, ctt + + +class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image. + + The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous + diffusion timestep. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2111.14822 + + Args: + num_vec_classes (`int`): + The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked + latent pixel. + + num_train_timesteps (`int`): + Number of diffusion steps used to train the model. + + alpha_cum_start (`float`): + The starting cumulative alpha value. + + alpha_cum_end (`float`): + The ending cumulative alpha value. + + gamma_cum_start (`float`): + The starting cumulative gamma value. + + gamma_cum_end (`float`): + The ending cumulative gamma value. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_vec_classes: int, + num_train_timesteps: int = 100, + alpha_cum_start: float = 0.99999, + alpha_cum_end: float = 0.000009, + gamma_cum_start: float = 0.000009, + gamma_cum_end: float = 0.99999, + ): + self.num_embed = num_vec_classes + + # By convention, the index for the mask class is the last class index + self.mask_class = self.num_embed - 1 + + at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end) + ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end) + + num_non_mask_classes = self.num_embed - 1 + bt = (1 - at - ct) / num_non_mask_classes + btt = (1 - att - ctt) / num_non_mask_classes + + at = torch.tensor(at.astype("float64")) + bt = torch.tensor(bt.astype("float64")) + ct = torch.tensor(ct.astype("float64")) + log_at = torch.log(at) + log_bt = torch.log(bt) + log_ct = torch.log(ct) + + att = torch.tensor(att.astype("float64")) + btt = torch.tensor(btt.astype("float64")) + ctt = torch.tensor(ctt.astype("float64")) + log_cumprod_at = torch.log(att) + log_cumprod_bt = torch.log(btt) + log_cumprod_ct = torch.log(ctt) + + self.log_at = log_at.float() + self.log_bt = log_bt.float() + self.log_ct = log_ct.float() + self.log_cumprod_at = log_cumprod_at.float() + self.log_cumprod_bt = log_cumprod_bt.float() + self.log_cumprod_ct = log_cumprod_ct.float() + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + device (`str` or `torch.device`): + device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on. + """ + self.num_inference_steps = num_inference_steps + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.log_at = self.log_at.to(device) + self.log_bt = self.log_bt.to(device) + self.log_ct = self.log_ct.to(device) + self.log_cumprod_at = self.log_cumprod_at.to(device) + self.log_cumprod_bt = self.log_cumprod_bt.to(device) + self.log_cumprod_ct = self.log_cumprod_ct.to(device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.long, + sample: torch.LongTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[VQDiffusionSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the + docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed. + + Args: + log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + + t (`torch.long`): + The timestep that determines which transition matrices are used. + + x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t` + + generator: (`torch.Generator` or None): + RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from. + + return_dict (`bool`): + option for returning tuple rather than VQDiffusionSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + if timestep == 0: + log_p_x_t_min_1 = model_output + else: + log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep) + + log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator) + + x_t_min_1 = log_p_x_t_min_1.argmax(dim=1) + + if not return_dict: + return (x_t_min_1,) + + return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1) + + def q_posterior(self, log_p_x_0, x_t, t): + """ + Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). + + Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only + forward probabilities. + + Equation (11) stated in terms of forward probabilities via Equation (5): + + Where: + - the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0) + + p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) ) + + Args: + log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + + x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t` + + t (torch.Long): + The timestep that determines which transition matrix is used. + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`: + The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). + """ + log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed) + + log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True + ) + + log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False + ) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q = log_p_x_0 - log_q_x_t_given_x_0 + + # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... , + # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n + q = q - q_log_sum_exp + + # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # . . . + # . . . + # . . . + # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # c_cumulative_{t-1} ... c_cumulative_{t-1} + q = self.apply_cumulative_transitions(q, t - 1) + + # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n + # . . . + # . . . + # . . . + # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n + # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 + log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp + + # For each column, there are two possible cases. + # + # Where: + # - sum(p_n(x_0))) is summing over all classes for x_0 + # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's) + # - C_j is the class transitioning to + # + # 1. x_t is masked i.e. x_t = c_k + # + # Simplifying the expression, the column vector is: + # . + # . + # . + # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0))) + # . + # . + # . + # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0)) + # + # From equation (11) stated in terms of forward probabilities, the last row is trivially verified. + # + # For the other rows, we can state the equation as ... + # + # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})] + # + # This verifies the other rows. + # + # 2. x_t is not masked + # + # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i: + # . + # . + # . + # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # 0 + # + # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities. + return log_p_x_t_min_1 + + def log_Q_t_transitioning_to_known_class( + self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool + ): + """ + Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each + latent pixel in `x_t`. + + See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix + is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs. + + Args: + t (torch.Long): + The timestep that determines which transition matrix is used. + + x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t`. + + log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`): + The log one-hot vectors of `x_t` + + cumulative (`bool`): + If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`, + we use the cumulative transition matrix `0`->`t`. + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`: + Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability + transition matrix. + + When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be + masked. + + Where: + - `q_n` is the probability distribution for the forward process of the `n`th latent pixel. + - C_0 is a class of a latent pixel embedding + - C_k is the class of the masked latent pixel + + non-cumulative result (omitting logarithms): + ``` + q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0) + . . . + . . . + . . . + q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k) + ``` + + cumulative result (omitting logarithms): + ``` + q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0) + . . . + . . . + . . . + q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1}) + ``` + """ + if cumulative: + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + else: + a = self.log_at[t] + b = self.log_bt[t] + c = self.log_ct[t] + + if not cumulative: + # The values in the onehot vector can also be used as the logprobs for transitioning + # from masked latent pixels. If we are not calculating the cumulative transitions, + # we need to save these vectors to be re-appended to the final matrix so the values + # aren't overwritten. + # + # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector + # if x_t is not masked + # + # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector + # if x_t is masked + log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1) + + # `index_to_log_onehot` will add onehot vectors for masked pixels, + # so the default one hot matrix has one too many rows. See the doc string + # for an explanation of the dimensionality of the returned matrix. + log_onehot_x_t = log_onehot_x_t[:, :-1, :] + + # this is a cheeky trick to produce the transition probabilities using log one-hot vectors. + # + # Don't worry about what values this sets in the columns that mark transitions + # to masked latent pixels. They are overwrote later with the `mask_class_mask`. + # + # Looking at the below logspace formula in non-logspace, each value will evaluate to either + # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column + # or + # `0 * a + b = b` where `log_Q_t` has the 0 values in the column. + # + # See equation 7 for more details. + log_Q_t = (log_onehot_x_t + a).logaddexp(b) + + # The whole column of each masked pixel is `c` + mask_class_mask = x_t == self.mask_class + mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1) + log_Q_t[mask_class_mask] = c + + if not cumulative: + log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1) + + return log_Q_t + + def apply_cumulative_transitions(self, q, t): + bsz = q.shape[0] + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + + num_latent_pixels = q.shape[2] + c = c.expand(bsz, 1, num_latent_pixels) + + q = (q + a).logaddexp(b) + q = torch.cat((q, c), dim=1) + + return q diff --git a/diffusers/training_utils.py b/diffusers/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa9ed64554bf8830e35efd220a77bd2de207f18 --- /dev/null +++ b/diffusers/training_utils.py @@ -0,0 +1,314 @@ +import contextlib +import copy +import random +from typing import Any, Dict, Iterable, Optional, Union + +import numpy as np +import torch + +from .utils import deprecate, is_transformers_available + + +if is_transformers_available(): + import transformers + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + + +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + model_cls: Optional[Any] = None, + model_config: Dict[str, Any] = None, + **kwargs, + ): + """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA + weights will be stored on CPU. + + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility + use_ema_warmup = True + + if kwargs.get("max_value", None) is not None: + deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." + deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) + decay = kwargs["max_value"] + + if kwargs.get("min_value", None) is not None: + deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." + deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) + min_decay = kwargs["min_value"] + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + if kwargs.get("device", None) is not None: + deprecation_message = "The `device` argument is deprecated. Please use `to` instead." + deprecate("device", "1.0.0", deprecation_message, standard_warn=False) + self.to(device=kwargs["device"]) + + self.temp_stored_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + self.optimization_step = 0 + self.cur_decay_value = None # set in `step()` + + self.model_cls = model_cls + self.model_config = model_config + + @classmethod + def from_pretrained(cls, path, model_cls) -> "EMAModel": + _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) + model = model_cls.from_pretrained(path) + + ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) + + ema_model.load_state_dict(ema_kwargs) + return ema_model + + def save_pretrained(self, path): + if self.model_cls is None: + raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") + + if self.model_config is None: + raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") + + model = self.model_cls.from_config(self.model_config) + state_dict = self.state_dict() + state_dict.pop("shadow_params", None) + + model.register_to_config(**state_dict) + self.copy_to(model.parameters()) + model.save_pretrained(path) + + def get_decay(self, optimization_step: int) -> float: + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + self.cur_decay_value = decay + one_minus_decay = 1 - decay + + context_manager = contextlib.nullcontext + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + import deepspeed + + for s_param, param in zip(self.shadow_params, parameters): + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + + with context_manager(): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.min_decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + } + + def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Save the current parameters for restoring later. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] + + def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: + affecting the original optimization process. Store the parameters before the `copy_to()` method. After + validation (or model saving), use this to restore the former parameters. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + if self.temp_stored_params is None: + raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) + + # Better memory-wise. + self.temp_stored_params = None + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Args: + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + shadow_params = state_dict.get("shadow_params", None) + if shadow_params is not None: + self.shadow_params = shadow_params + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") diff --git a/diffusers/utils/__init__.py b/diffusers/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98fac64497e7efa4a881124dd778c4f3084402e8 --- /dev/null +++ b/diffusers/utils/__init__.py @@ -0,0 +1,122 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +from packaging import version + +from .. import __version__ +from .accelerate_utils import apply_forward_hook +from .constants import ( + CONFIG_NAME, + DEPRECATED_REVISION_ARGS, + DIFFUSERS_CACHE, + DIFFUSERS_DYNAMIC_MODULE_NAME, + FLAX_WEIGHTS_NAME, + HF_MODULES_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + ONNX_EXTERNAL_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, +) +from .deprecation_utils import deprecate +from .doc_utils import replace_example_docstring +from .dynamic_modules_utils import get_class_from_dynamic_module +from .hub_utils import ( + HF_HUB_OFFLINE, + _add_variant, + _get_model_file, + extract_commit_hash, + http_user_agent, +) +from .import_utils import ( + BACKENDS_MAPPING, + ENV_VARS_TRUE_AND_AUTO_VALUES, + ENV_VARS_TRUE_VALUES, + USE_JAX, + USE_TF, + USE_TORCH, + DummyObject, + OptionalDependencyNotAvailable, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_flax_available, + is_ftfy_available, + is_inflect_available, + is_invisible_watermark_available, + is_k_diffusion_available, + is_k_diffusion_version, + is_librosa_available, + is_note_seq_available, + is_omegaconf_available, + is_onnx_available, + is_safetensors_available, + is_scipy_available, + is_tensorboard_available, + is_tf_available, + is_torch_available, + is_torch_version, + is_torchsde_available, + is_transformers_available, + is_transformers_version, + is_unidecode_available, + is_wandb_available, + is_xformers_available, + requires_backends, +) +from .logging import get_logger +from .outputs import BaseOutput +from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil +from .torch_utils import is_compiled_module, randn_tensor + + +if is_torch_available(): + from .testing_utils import ( + floats_tensor, + load_hf_numpy, + load_image, + load_numpy, + load_pt, + nightly, + parse_flag_from_env, + print_tensor_test, + require_torch_2, + require_torch_gpu, + skip_mps, + slow, + torch_all_close, + torch_device, + ) + from .torch_utils import maybe_allow_in_graph + +from .testing_utils import export_to_gif, export_to_video + + +logger = get_logger(__name__) + + +def check_min_version(min_version): + if version.parse(__version__) < version.parse(min_version): + if "dev" in min_version: + error_message = ( + "This example requires a source install from HuggingFace diffusers (see " + "`https://huggingface.co/docs/diffusers/installation#install-from-source`)," + ) + else: + error_message = f"This example requires a minimum version of {min_version}," + error_message += f" but the version found is {__version__}.\n" + raise ImportError(error_message) diff --git a/diffusers/utils/__pycache__/__init__.cpython-310.pyc b/diffusers/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e310135040760e678de97f01b045e5b456e2892f Binary files /dev/null and b/diffusers/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/__init__.cpython-38.pyc b/diffusers/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ed1374b0d9691344cd9a04cae2ef91d26423a79 Binary files /dev/null and b/diffusers/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc b/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10d05a0d26ad25743da806d07a163c95c3880789 Binary files /dev/null and b/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/accelerate_utils.cpython-38.pyc b/diffusers/utils/__pycache__/accelerate_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8dab26b80c017ad19e911dc0890f7dd85d68417 Binary files /dev/null and b/diffusers/utils/__pycache__/accelerate_utils.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/constants.cpython-310.pyc b/diffusers/utils/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e980e4d8349a7c3ea515179fceff00ca1b2a1ee8 Binary files /dev/null and b/diffusers/utils/__pycache__/constants.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/constants.cpython-38.pyc b/diffusers/utils/__pycache__/constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76bdcbfd60909c4a3303eb95f69fb0e7ec435293 Binary files /dev/null and b/diffusers/utils/__pycache__/constants.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc b/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df4b7753e7404d9cc591e339b470637a2917083f Binary files /dev/null and b/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/deprecation_utils.cpython-38.pyc b/diffusers/utils/__pycache__/deprecation_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de58ba991d9ae78013c1553d731bf03eceef760a Binary files /dev/null and b/diffusers/utils/__pycache__/deprecation_utils.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc b/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f129c84c0fbbac511ae7533cef7b5be4056350b Binary files /dev/null and b/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/doc_utils.cpython-38.pyc b/diffusers/utils/__pycache__/doc_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28f4c6ec6bba7a5d925f5d32bd115faf79f1943f Binary files /dev/null and b/diffusers/utils/__pycache__/doc_utils.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba514a0e2bf5d2980575b28a242d7599aaa6bdb5 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cef6bd6fdbbcaedf57f4dc2c181861fab7a614db Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5e68c7855ea036ae74664d258b3ff82db52d9a4 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_flax_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_flax_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1316c453f9831487ced9d28fc940060524c0161 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_flax_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcf2cff2c2a1b5fff7243c697ff3353659f534af Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f632fbc289333a27a2799923f0e7b60697e9715c Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72149f1f4a5b1f34f42febc8821681a82d72bc30 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c93d3c999f715c5afc190d9baf886dc299c90770 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fad3b97c7874d7193b90dda87912a25e3125e07 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03d85bc5dec6fb476bf950a08fe0dd7f3489e467 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10f5a6e9146ff91bab5fefcca698cb085f8ba6f6 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..442526b89d7d4aeaaaa2552df0c427132070b18d Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_invisible_watermark_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_invisible_watermark_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a028ce8b900009303ca9eae3e8023a986bf97519 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_invisible_watermark_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e20efeffe6da968d906a44cae8a341ca1aceb0b Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d88c82928776cf41bc2c36f72939a665ac2c4b3f Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..987ccbaa93fdf4a23aaeaa1d1521d7c7c96e83ef Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e35e45dabcf2437f7800bebd4d6d781da9c367d1 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..157e64060d57b362f1f6013262a64f8bf1bd37bc Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe8d329eec939b64e5ffeca2a939beb5e3d8d990 Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caf5da69df96565b1a708a718e2c1a8de07970ab Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc b/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d7927f515b35ab0785380733985c90ff47ed3e1 Binary files /dev/null and b/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-38.pyc b/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60c4b66fe6c3b949f0d3add87a4a13cb42aa5a5d Binary files /dev/null and b/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc b/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..509d816b91cdcaa9f8293c73d938994d7b2d4442 Binary files /dev/null and b/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/hub_utils.cpython-38.pyc b/diffusers/utils/__pycache__/hub_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d85317db4cb3fb29fce9981f3c34c4940a039b Binary files /dev/null and b/diffusers/utils/__pycache__/hub_utils.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/import_utils.cpython-310.pyc b/diffusers/utils/__pycache__/import_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74b8269dff3936b8b5f441bc8e02b70ca2a12363 Binary files /dev/null and b/diffusers/utils/__pycache__/import_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/import_utils.cpython-38.pyc b/diffusers/utils/__pycache__/import_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8eac94029ce0f02690721de3e96a431bdb422ff Binary files /dev/null and b/diffusers/utils/__pycache__/import_utils.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/logging.cpython-310.pyc b/diffusers/utils/__pycache__/logging.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f071c47dcbc9316b467b4010e4fe33015f95d93 Binary files /dev/null and b/diffusers/utils/__pycache__/logging.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/logging.cpython-38.pyc b/diffusers/utils/__pycache__/logging.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e29dc506db303bf13c90171387c3f119f0fb7f3 Binary files /dev/null and b/diffusers/utils/__pycache__/logging.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/outputs.cpython-310.pyc b/diffusers/utils/__pycache__/outputs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2acb989510b087333ed6bc924e0df165bfe1ae1 Binary files /dev/null and b/diffusers/utils/__pycache__/outputs.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/outputs.cpython-38.pyc b/diffusers/utils/__pycache__/outputs.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8df1d6a0faf156add97ecd8c7238f40d886d7982 Binary files /dev/null and b/diffusers/utils/__pycache__/outputs.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc b/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05b64bee2f2ab9152d2ce1b1018af376553ceedc Binary files /dev/null and b/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/pil_utils.cpython-38.pyc b/diffusers/utils/__pycache__/pil_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b655c24119aab52e8b361b6084c91ce8f5b9926 Binary files /dev/null and b/diffusers/utils/__pycache__/pil_utils.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc b/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d703eef3788601acccc5c34c642d16d51fc875db Binary files /dev/null and b/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/testing_utils.cpython-38.pyc b/diffusers/utils/__pycache__/testing_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fa118df35a47eb9424e41a7ee687586ea5adb65 Binary files /dev/null and b/diffusers/utils/__pycache__/testing_utils.cpython-38.pyc differ diff --git a/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc b/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79e119de6e598661628ff58b395c8e215d096f97 Binary files /dev/null and b/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc differ diff --git a/diffusers/utils/__pycache__/torch_utils.cpython-38.pyc b/diffusers/utils/__pycache__/torch_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e907758a309392e44565a1f44f19a7726f99a833 Binary files /dev/null and b/diffusers/utils/__pycache__/torch_utils.cpython-38.pyc differ diff --git a/diffusers/utils/accelerate_utils.py b/diffusers/utils/accelerate_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..10a83e1dd209cca198f4038d0d7e7228f9671859 --- /dev/null +++ b/diffusers/utils/accelerate_utils.py @@ -0,0 +1,48 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Accelerate utilities: Utilities related to accelerate +""" + +from packaging import version + +from .import_utils import is_accelerate_available + + +if is_accelerate_available(): + import accelerate + + +def apply_forward_hook(method): + """ + Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful + for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the + appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`]. + + This decorator looks inside the internal `_hf_hook` property to find a registered offload hook. + + :param method: The method to decorate. This method should be a method of a PyTorch module. + """ + if not is_accelerate_available(): + return method + accelerate_version = version.parse(accelerate.__version__).base_version + if version.parse(accelerate_version) < version.parse("0.17.0"): + return method + + def wrapper(self, *args, **kwargs): + if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): + self._hf_hook.pre_forward(self) + return method(self, *args, **kwargs) + + return wrapper diff --git a/diffusers/utils/constants.py b/diffusers/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e60a2a873b29a7d3adffbd7179be1670b3b417 --- /dev/null +++ b/diffusers/utils/constants.py @@ -0,0 +1,32 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home + + +default_cache_path = HUGGINGFACE_HUB_CACHE + + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "diffusion_pytorch_model.bin" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" +ONNX_WEIGHTS_NAME = "model.onnx" +SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" +HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" +DIFFUSERS_CACHE = default_cache_path +DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" +HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) +DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] diff --git a/diffusers/utils/deprecation_utils.py b/diffusers/utils/deprecation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f482deddd2f46b8d2e29d5229faa0e9a21f2fd98 --- /dev/null +++ b/diffusers/utils/deprecation_utils.py @@ -0,0 +1,49 @@ +import inspect +import warnings +from typing import Any, Dict, Optional, Union + +from packaging import version + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): + from .. import __version__ + + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): + raise ValueError( + f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" + f" version {__version__} is >= {version_name}" + ) + + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." + + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel) + + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values diff --git a/diffusers/utils/doc_utils.py b/diffusers/utils/doc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1f87743f99802931334bd51bf99985775116d59 --- /dev/null +++ b/diffusers/utils/doc_utils.py @@ -0,0 +1,38 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Doc utilities: Utilities related to documentation +""" +import re + + +def replace_example_docstring(example_docstring): + def docstring_decorator(fn): + func_doc = fn.__doc__ + lines = func_doc.split("\n") + i = 0 + while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None: + i += 1 + if i < len(lines): + lines[i] = example_docstring + func_doc = "\n".join(lines) + else: + raise ValueError( + f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, " + f"current docstring is:\n{func_doc}" + ) + fn.__doc__ = func_doc + return fn + + return docstring_decorator diff --git a/diffusers/utils/dummy_flax_and_transformers_objects.py b/diffusers/utils/dummy_flax_and_transformers_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..162bac1c4331149c4b5abde1eadd8013ab0cda99 --- /dev/null +++ b/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -0,0 +1,62 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + +class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + +class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + +class FlaxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) diff --git a/diffusers/utils/dummy_flax_objects.py b/diffusers/utils/dummy_flax_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb80d136f338d193c67773266355956afd1d98a --- /dev/null +++ b/diffusers/utils/dummy_flax_objects.py @@ -0,0 +1,197 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class FlaxControlNetModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxModelMixin(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxUNet2DConditionModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxAutoencoderKL(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxDDIMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxDDPMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxKarrasVeScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxLMSDiscreteScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxPNDMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxSchedulerMixin(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxScoreSdeVeScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) diff --git a/diffusers/utils/dummy_note_seq_objects.py b/diffusers/utils/dummy_note_seq_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..c02d0b015aedc37c01fb3b843bc79547aae5da68 --- /dev/null +++ b/diffusers/utils/dummy_note_seq_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class MidiProcessor(metaclass=DummyObject): + _backends = ["note_seq"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["note_seq"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["note_seq"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["note_seq"]) diff --git a/diffusers/utils/dummy_onnx_objects.py b/diffusers/utils/dummy_onnx_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..bde5f6ad0793e2d81bc638600b46ff81748d09ee --- /dev/null +++ b/diffusers/utils/dummy_onnx_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class OnnxRuntimeModel(metaclass=DummyObject): + _backends = ["onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["onnx"]) diff --git a/diffusers/utils/dummy_pt_objects.py b/diffusers/utils/dummy_pt_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..b955ec5320de4973a4e8eaf2f039953189d83228 --- /dev/null +++ b/diffusers/utils/dummy_pt_objects.py @@ -0,0 +1,810 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class AutoencoderKL(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ModelMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class MultiAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PriorTransformer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class T2IAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class T5FilmDecoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class Transformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet1DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet2DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UNet3DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class VQModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +def get_constant_schedule(*args, **kwargs): + requires_backends(get_constant_schedule, ["torch"]) + + +def get_constant_schedule_with_warmup(*args, **kwargs): + requires_backends(get_constant_schedule_with_warmup, ["torch"]) + + +def get_cosine_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_schedule_with_warmup, ["torch"]) + + +def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) + + +def get_linear_schedule_with_warmup(*args, **kwargs): + requires_backends(get_linear_schedule_with_warmup, ["torch"]) + + +def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): + requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"]) + + +def get_scheduler(*args, **kwargs): + requires_backends(get_scheduler, ["torch"]) + + +class AudioPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ConsistencyModelPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DanceDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDIMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDPMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DiTPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ImagePipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KarrasVePipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LDMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LDMSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PNDMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class RePaintPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ScoreSdeVePipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class CMStochasticIterativeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDIMInverseScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDIMParallelScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDIMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDPMParallelScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DDPMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DEISMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DPMSolverSinglestepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EulerAncestralDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EulerDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class HeunDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class IPNDMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KarrasVeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KDPM2DiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PNDMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class RePaintScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SchedulerMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ScoreSdeVeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UnCLIPScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class UniPCMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class VQDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EMAModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) diff --git a/diffusers/utils/dummy_torch_and_librosa_objects.py b/diffusers/utils/dummy_torch_and_librosa_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..2088bc4a744198284f22fe54e6f1055cf3568566 --- /dev/null +++ b/diffusers/utils/dummy_torch_and_librosa_objects.py @@ -0,0 +1,32 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class AudioDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "librosa"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "librosa"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "librosa"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "librosa"]) + + +class Mel(metaclass=DummyObject): + _backends = ["torch", "librosa"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "librosa"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "librosa"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "librosa"]) diff --git a/diffusers/utils/dummy_torch_and_scipy_objects.py b/diffusers/utils/dummy_torch_and_scipy_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ff25863822b04971d2c6dfdc17f5b28774cf05 --- /dev/null +++ b/diffusers/utils/dummy_torch_and_scipy_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class LMSDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch", "scipy"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "scipy"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "scipy"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "scipy"]) diff --git a/diffusers/utils/dummy_torch_and_torchsde_objects.py b/diffusers/utils/dummy_torch_and_torchsde_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..a81bbb316f32267c31b06598519f1eef9ddde643 --- /dev/null +++ b/diffusers/utils/dummy_torch_and_torchsde_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class DPMSolverSDEScheduler(metaclass=DummyObject): + _backends = ["torch", "torchsde"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "torchsde"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) diff --git a/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py b/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4d59c537fa3dbaa76699836d9024823c4e36bd --- /dev/null +++ b/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py @@ -0,0 +1,62 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "invisible_watermark"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + +class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "invisible_watermark"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + +class StableDiffusionXLInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "invisible_watermark"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + +class StableDiffusionXLPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "invisible_watermark"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) diff --git a/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py b/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..56836f0b6d77b8daa25e956101694863e418339f --- /dev/null +++ b/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class StableDiffusionKDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "k_diffusion"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "k_diffusion"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "k_diffusion"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "k_diffusion"]) diff --git a/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..b7afad8226b87292100270e3e7daad6885be0e7f --- /dev/null +++ b/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -0,0 +1,92 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + +class StableDiffusionOnnxPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) diff --git a/diffusers/utils/dummy_torch_and_transformers_objects.py b/diffusers/utils/dummy_torch_and_transformers_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..016760337c696890e1ec7033c00129d195d8f6a3 --- /dev/null +++ b/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -0,0 +1,962 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AltDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AudioLDMPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class CycleDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFInpaintingPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ImageTextPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyPriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22ControlnetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22Img2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22InpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyV22PriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LDMTextToImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class PaintByExamplePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class SemanticStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ShapEImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ShapEPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionAdapterPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionDiffEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionLDM3DPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionModelEditingPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPanoramaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionParadigmsPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPipelineSafe(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionSAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableUnCLIPPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class TextToVideoSDPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class TextToVideoZeroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UnCLIPImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UnCLIPPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserModel(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserTextDecoder(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VideoToVideoSDPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VQDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) diff --git a/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py b/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..fbde04e33f0abd86d12f3dee048a4f0585c9f19d --- /dev/null +++ b/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class SpectrogramDiffusionPipeline(metaclass=DummyObject): + _backends = ["transformers", "torch", "note_seq"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers", "torch", "note_seq"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["transformers", "torch", "note_seq"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["transformers", "torch", "note_seq"]) diff --git a/diffusers/utils/dynamic_modules_utils.py b/diffusers/utils/dynamic_modules_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5b0952f0b514cb52e63fdac8a780ddc9482a5b9d --- /dev/null +++ b/diffusers/utils/dynamic_modules_utils.py @@ -0,0 +1,456 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities to dynamically load objects from the Hub.""" + +import importlib +import inspect +import json +import os +import re +import shutil +import sys +from pathlib import Path +from typing import Dict, Optional, Union +from urllib import request + +from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info +from packaging import version + +from .. import __version__ +from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging + + +COMMUNITY_PIPELINES_URL = ( + "https://raw.githubusercontent.com/huggingface/diffusers/{revision}/examples/community/{pipeline}.py" +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_diffusers_versions(): + url = "https://pypi.org/pypi/diffusers/json" + releases = json.loads(request.urlopen(url).read())["releases"].keys() + return sorted(releases, key=lambda x: version.Version(x)) + + +def init_hf_modules(): + """ + Creates the cache directory for modules with an init, and adds it to the Python path. + """ + # This function has already been executed if HF_MODULES_CACHE already is in the Python path. + if HF_MODULES_CACHE in sys.path: + return + + sys.path.append(HF_MODULES_CACHE) + os.makedirs(HF_MODULES_CACHE, exist_ok=True) + init_path = Path(HF_MODULES_CACHE) / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def create_dynamic_module(name: Union[str, os.PathLike]): + """ + Creates a dynamic module in the cache directory for modules. + """ + init_hf_modules() + dynamic_module_path = Path(HF_MODULES_CACHE) / name + # If the parent module does not exist yet, recursively create it. + if not dynamic_module_path.parent.exists(): + create_dynamic_module(dynamic_module_path.parent) + os.makedirs(dynamic_module_path, exist_ok=True) + init_path = dynamic_module_path / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def get_relative_imports(module_file): + """ + Get the list of modules that are relatively imported in a module file. + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + with open(module_file, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import .xxx` + relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + # Unique-ify + return list(set(relative_imports)) + + +def get_relative_import_files(module_file): + """ + Get the list of all files that are needed for a given module. Note that this function recurses through the relative + imports (if a imports b and b imports c, it will return module files for b and c). + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + no_change = False + files_to_check = [module_file] + all_relative_imports = [] + + # Let's recurse through all relative imports + while not no_change: + new_imports = [] + for f in files_to_check: + new_imports.extend(get_relative_imports(f)) + + module_path = Path(module_file).parent + new_import_files = [str(module_path / m) for m in new_imports] + new_import_files = [f for f in new_import_files if f not in all_relative_imports] + files_to_check = [f"{f}.py" for f in new_import_files] + + no_change = len(new_import_files) == 0 + all_relative_imports.extend(files_to_check) + + return all_relative_imports + + +def check_imports(filename): + """ + Check if the current Python environment contains all the libraries that are imported in a file. + """ + with open(filename, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import xxx` + imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from xxx import yyy` + imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + # Only keep the top-level module + imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + + # Unique-ify and test we got them all + imports = list(set(imports)) + missing_packages = [] + for imp in imports: + try: + importlib.import_module(imp) + except ImportError: + missing_packages.append(imp) + + if len(missing_packages) > 0: + raise ImportError( + "This modeling file requires the following packages that were not found in your environment: " + f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" + ) + + return get_relative_imports(filename) + + +def get_class_in_module(class_name, module_path): + """ + Import a module on the cache directory for modules and extract a class from it. + """ + module_path = module_path.replace(os.path.sep, ".") + module = importlib.import_module(module_path) + + if class_name is None: + return find_pipeline_class(module) + return getattr(module, class_name) + + +def find_pipeline_class(loaded_module): + """ + Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class + inheriting from `DiffusionPipeline`. + """ + from ..pipelines import DiffusionPipeline + + cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass)) + + pipeline_class = None + for cls_name, cls in cls_members.items(): + if ( + cls_name != DiffusionPipeline.__name__ + and issubclass(cls, DiffusionPipeline) + and cls.__module__.split(".")[0] != "diffusers" + ): + if pipeline_class is not None: + raise ValueError( + f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:" + f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in" + f" {loaded_module}." + ) + pipeline_class = cls + + return pipeline_class + + +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, +): + """ + Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private + or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + Returns: + `str`: The path to the module inside the cache. + """ + # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) + + if os.path.isfile(module_file_or_url): + resolved_module_file = module_file_or_url + submodule = "local" + elif pretrained_model_name_or_path.count("/") == 0: + available_versions = get_diffusers_versions() + # cut ".dev0" + latest_version = "v" + ".".join(__version__.split(".")[:3]) + + # retrieve github version that matches + if revision is None: + revision = latest_version if latest_version[1:] in available_versions else "main" + logger.info(f"Defaulting to latest_version: {revision}.") + elif revision in available_versions: + revision = f"v{revision}" + elif revision == "main": + revision = revision + else: + raise ValueError( + f"`custom_revision`: {revision} does not exist. Please make sure to choose one of" + f" {', '.join(available_versions + ['main'])}." + ) + + # community pipeline on GitHub + github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path) + try: + resolved_module_file = cached_download( + github_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=False, + ) + submodule = "git" + module_file = pretrained_model_name_or_path + ".py" + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + else: + try: + # Load from URL or cache if already cached + resolved_module_file = hf_hub_download( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/"))) + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + # Check we have all the requirements in our environment + modules_needed = check_imports(resolved_module_file) + + # Now we move the module inside our cached dynamic modules. + full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + if submodule == "local" or submodule == "git": + # We always copy local files (we could hash the file to see if there was a change, and give them the name of + # that hash, to only copy when there is a modification but it seems overkill for now). + # The only reason we do the copy is to avoid putting too many folders in sys.path. + shutil.copy(resolved_module_file, submodule_path / module_file) + for module_needed in modules_needed: + module_needed = f"{module_needed}.py" + shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) + else: + # Get the commit hash + # TODO: we will get this info in the etag soon, so retrieve it from there and not here. + if isinstance(use_auth_token, str): + token = use_auth_token + elif use_auth_token is True: + token = HfFolder.get_token() + else: + token = None + + commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha + + # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the + # benefit of versioning. + submodule_path = submodule_path / commit_hash + full_submodule = full_submodule + os.path.sep + commit_hash + create_dynamic_module(full_submodule) + + if not (submodule_path / module_file).exists(): + shutil.copy(resolved_module_file, submodule_path / module_file) + # Make sure we also have every file with relative + for module_needed in modules_needed: + if not (submodule_path / module_needed).exists(): + get_cached_module_file( + pretrained_model_name_or_path, + f"{module_needed}.py", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + return os.path.join(full_submodule, module_file) + + +def get_class_from_dynamic_module( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + class_name: Optional[str] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Extracts a class from a module file, present in the local folder or repository of a model. + + + + Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should + therefore only be called on trusted repos. + + + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + class_name (`str`): + The name of the class to import in the module. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private + or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + Returns: + `type`: The class, dynamically imported from the module. + + Examples: + + ```python + # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") + ```""" + # And lastly we get the class inside our newly created module + final_module = get_cached_module_file( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + return get_class_in_module(class_name, final_module.replace(".py", "")) diff --git a/diffusers/utils/hub_utils.py b/diffusers/utils/hub_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4f0cf00a5c5d0d303ba53f62fbf027c0bc31ad49 --- /dev/null +++ b/diffusers/utils/hub_utils.py @@ -0,0 +1,361 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import re +import sys +import traceback +import warnings +from pathlib import Path +from typing import Dict, Optional, Union +from uuid import uuid4 + +from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami +from huggingface_hub.file_download import REGEX_COMMIT_HASH +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + is_jinja_available, +) +from packaging import version +from requests import HTTPError + +from .. import __version__ +from .constants import ( + DEPRECATED_REVISION_ARGS, + DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, +) +from .import_utils import ( + ENV_VARS_TRUE_VALUES, + _flax_version, + _jax_version, + _onnxruntime_version, + _torch_version, + is_flax_available, + is_onnx_available, + is_torch_available, +) +from .logging import get_logger + + +logger = get_logger(__name__) + + +MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md" +SESSION_ID = uuid4().hex +HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES +DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES +HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/" + + +def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: + """ + Formats a user-agent string with basic info about a request. + """ + ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" + if DISABLE_TELEMETRY or HF_HUB_OFFLINE: + return ua + "; telemetry/off" + if is_torch_available(): + ua += f"; torch/{_torch_version}" + if is_flax_available(): + ua += f"; jax/{_jax_version}" + ua += f"; flax/{_flax_version}" + if is_onnx_available(): + ua += f"; onnxruntime/{_onnxruntime_version}" + # CI will set this value to True + if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: + ua += "; is_ci/true" + if isinstance(user_agent, dict): + ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + return ua + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def create_model_card(args, model_name): + if not is_jinja_available(): + raise ValueError( + "Modelcard rendering is based on Jinja templates." + " Please make sure to have `jinja` installed before using `create_model_card`." + " To install it, please run `pip install Jinja2`." + ) + + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: + return + + hub_token = args.hub_token if hasattr(args, "hub_token") else None + repo_name = get_full_repo_name(model_name, token=hub_token) + + model_card = ModelCard.from_template( + card_data=ModelCardData( # Card metadata object that will be converted to YAML block + language="en", + license="apache-2.0", + library_name="diffusers", + tags=[], + datasets=args.dataset_name, + metrics=[], + ), + template_path=MODEL_CARD_TEMPLATE_PATH, + model_name=model_name, + repo_name=repo_name, + dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None, + learning_rate=args.learning_rate, + train_batch_size=args.train_batch_size, + eval_batch_size=args.eval_batch_size, + gradient_accumulation_steps=( + args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None + ), + adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None, + adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None, + adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None, + adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None, + lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None, + lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None, + ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None, + ema_power=args.ema_power if hasattr(args, "ema_power") else None, + ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None, + mixed_precision=args.mixed_precision, + ) + + card_path = os.path.join(args.output_dir, "README.md") + model_card.save(card_path) + + +def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None): + """ + Extracts the commit hash from a resolved filename toward a cache file. + """ + if resolved_file is None or commit_hash is not None: + return commit_hash + resolved_file = str(Path(resolved_file).as_posix()) + search = re.search(r"snapshots/([^/]+)/", resolved_file) + if search is None: + return None + commit_hash = search.groups()[0] + return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None + + +# Old default cache path, potentially to be migrated. +# This logic was more or less taken from `transformers`, with the following differences: +# - Diffusers doesn't use custom environment variables to specify the cache path. +# - There is no need to migrate the cache format, just move the files to the new location. +hf_cache_home = os.path.expanduser( + os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +old_diffusers_cache = os.path.join(hf_cache_home, "diffusers") + + +def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None: + if new_cache_dir is None: + new_cache_dir = DIFFUSERS_CACHE + if old_cache_dir is None: + old_cache_dir = old_diffusers_cache + + old_cache_dir = Path(old_cache_dir).expanduser() + new_cache_dir = Path(new_cache_dir).expanduser() + for old_blob_path in old_cache_dir.glob("**/blobs/*"): + if old_blob_path.is_file() and not old_blob_path.is_symlink(): + new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir) + new_blob_path.parent.mkdir(parents=True, exist_ok=True) + os.replace(old_blob_path, new_blob_path) + try: + os.symlink(new_blob_path, old_blob_path) + except OSError: + logger.warning( + "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded." + ) + # At this point, old_cache_dir contains symlinks to the new cache (it can still be used). + + +cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt") +if not os.path.isfile(cache_version_file): + cache_version = 0 +else: + with open(cache_version_file) as f: + try: + cache_version = int(f.read()) + except ValueError: + cache_version = 0 + +if cache_version < 1: + old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0 + if old_cache_is_not_empty: + logger.warning( + "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your " + "existing cached models. This is a one-time operation, you can interrupt it or run it " + "later by calling `diffusers.utils.hub_utils.move_cache()`." + ) + try: + move_cache() + except Exception as e: + trace = "\n".join(traceback.format_tb(e.__traceback__)) + logger.error( + f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease " + "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole " + "message and we will do our best to help." + ) + +if cache_version < 1: + try: + os.makedirs(DIFFUSERS_CACHE, exist_ok=True) + with open(cache_version_file, "w") as f: + f.write("1") + except Exception: + logger.warning( + f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure " + "the directory exists and can be written to." + ) + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +def _get_model_file( + pretrained_model_name_or_path, + *, + weights_name, + subfolder, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + use_auth_token, + user_agent, + revision, + commit_hash=None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + # 1. First check if deprecated way of loading from branches is used + if ( + revision in DEPRECATED_REVISION_ARGS + and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) + and version.parse(version.parse(__version__).base_version) >= version.parse("0.20.0") + ): + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=_add_variant(weights_name, revision), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + warnings.warn( + f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + return model_file + except: # noqa: E722 + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", + FutureWarning, + ) + try: + # 2. Load model file as usual + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=weights_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + return model_file + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {weights_name} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {weights_name}" + ) diff --git a/diffusers/utils/import_utils.py b/diffusers/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f40d939a71194ee4325964838640f142fd1c964 --- /dev/null +++ b/diffusers/utils/import_utils.py @@ -0,0 +1,655 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" +import importlib.util +import operator as op +import os +import sys +from collections import OrderedDict +from typing import Union + +from huggingface_hub.utils import is_jinja_available # noqa: F401 +from packaging import version +from packaging.version import Version, parse + +from . import logging + + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper() + +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + +_torch_version = "N/A" +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logger.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info("Disabling PyTorch because USE_TORCH is set") + _torch_available = False + + +_tf_version = "N/A" +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + candidates = ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "intel-tensorflow-avx512", + "tensorflow-rocm", + "tensorflow-macos", + "tensorflow-aarch64", + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if version.parse(_tf_version) < version.parse("2"): + logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.") + _tf_available = False + else: + logger.info(f"TensorFlow version {_tf_version} available.") +else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False + +_jax_version = "N/A" +_flax_version = "N/A" +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None + if _flax_available: + try: + _jax_version = importlib_metadata.version("jax") + _flax_version = importlib_metadata.version("flax") + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + except importlib_metadata.PackageNotFoundError: + _flax_available = False +else: + _flax_available = False + +if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: + _safetensors_available = importlib.util.find_spec("safetensors") is not None + if _safetensors_available: + try: + _safetensors_version = importlib_metadata.version("safetensors") + logger.info(f"Safetensors version {_safetensors_version} available.") + except importlib_metadata.PackageNotFoundError: + _safetensors_available = False +else: + logger.info("Disabling Safetensors because USE_TF is set") + _safetensors_available = False + +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + + +_inflect_available = importlib.util.find_spec("inflect") is not None +try: + _inflect_version = importlib_metadata.version("inflect") + logger.debug(f"Successfully imported inflect version {_inflect_version}") +except importlib_metadata.PackageNotFoundError: + _inflect_available = False + + +_unidecode_available = importlib.util.find_spec("unidecode") is not None +try: + _unidecode_version = importlib_metadata.version("unidecode") + logger.debug(f"Successfully imported unidecode version {_unidecode_version}") +except importlib_metadata.PackageNotFoundError: + _unidecode_available = False + + +_onnxruntime_version = "N/A" +_onnx_available = importlib.util.find_spec("onnxruntime") is not None +if _onnx_available: + candidates = ( + "onnxruntime", + "onnxruntime-gpu", + "ort_nightly_gpu", + "onnxruntime-directml", + "onnxruntime-openvino", + "ort_nightly_directml", + "onnxruntime-rocm", + "onnxruntime-training", + ) + _onnxruntime_version = None + # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu + for pkg in candidates: + try: + _onnxruntime_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _onnx_available = _onnxruntime_version is not None + if _onnx_available: + logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") + +# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. +# _opencv_available = importlib.util.find_spec("opencv-python") is not None +try: + candidates = ( + "opencv-python", + "opencv-contrib-python", + "opencv-python-headless", + "opencv-contrib-python-headless", + ) + _opencv_version = None + for pkg in candidates: + try: + _opencv_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _opencv_available = _opencv_version is not None + if _opencv_available: + logger.debug(f"Successfully imported cv2 version {_opencv_version}") +except importlib_metadata.PackageNotFoundError: + _opencv_available = False + +_scipy_available = importlib.util.find_spec("scipy") is not None +try: + _scipy_version = importlib_metadata.version("scipy") + logger.debug(f"Successfully imported scipy version {_scipy_version}") +except importlib_metadata.PackageNotFoundError: + _scipy_available = False + +_librosa_available = importlib.util.find_spec("librosa") is not None +try: + _librosa_version = importlib_metadata.version("librosa") + logger.debug(f"Successfully imported librosa version {_librosa_version}") +except importlib_metadata.PackageNotFoundError: + _librosa_available = False + +_accelerate_available = importlib.util.find_spec("accelerate") is not None +try: + _accelerate_version = importlib_metadata.version("accelerate") + logger.debug(f"Successfully imported accelerate version {_accelerate_version}") +except importlib_metadata.PackageNotFoundError: + _accelerate_available = False + +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + _xformers_version = importlib_metadata.version("xformers") + if _torch_available: + import torch + + if version.Version(torch.__version__) < version.Version("1.12"): + raise ValueError("PyTorch should be >= 1.12") + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + +_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None +try: + _k_diffusion_version = importlib_metadata.version("k_diffusion") + logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}") +except importlib_metadata.PackageNotFoundError: + _k_diffusion_available = False + +_note_seq_available = importlib.util.find_spec("note_seq") is not None +try: + _note_seq_version = importlib_metadata.version("note_seq") + logger.debug(f"Successfully imported note-seq version {_note_seq_version}") +except importlib_metadata.PackageNotFoundError: + _note_seq_available = False + +_wandb_available = importlib.util.find_spec("wandb") is not None +try: + _wandb_version = importlib_metadata.version("wandb") + logger.debug(f"Successfully imported wandb version {_wandb_version }") +except importlib_metadata.PackageNotFoundError: + _wandb_available = False + +_omegaconf_available = importlib.util.find_spec("omegaconf") is not None +try: + _omegaconf_version = importlib_metadata.version("omegaconf") + logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}") +except importlib_metadata.PackageNotFoundError: + _omegaconf_available = False + +_tensorboard_available = importlib.util.find_spec("tensorboard") +try: + _tensorboard_version = importlib_metadata.version("tensorboard") + logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}") +except importlib_metadata.PackageNotFoundError: + _tensorboard_available = False + + +_compel_available = importlib.util.find_spec("compel") +try: + _compel_version = importlib_metadata.version("compel") + logger.debug(f"Successfully imported compel version {_compel_version}") +except importlib_metadata.PackageNotFoundError: + _compel_available = False + + +_ftfy_available = importlib.util.find_spec("ftfy") is not None +try: + _ftfy_version = importlib_metadata.version("ftfy") + logger.debug(f"Successfully imported ftfy version {_ftfy_version}") +except importlib_metadata.PackageNotFoundError: + _ftfy_available = False + + +_bs4_available = importlib.util.find_spec("bs4") is not None +try: + # importlib metadata under different name + _bs4_version = importlib_metadata.version("beautifulsoup4") + logger.debug(f"Successfully imported ftfy version {_bs4_version}") +except importlib_metadata.PackageNotFoundError: + _bs4_available = False + +_torchsde_available = importlib.util.find_spec("torchsde") is not None +try: + _torchsde_version = importlib_metadata.version("torchsde") + logger.debug(f"Successfully imported torchsde version {_torchsde_version}") +except importlib_metadata.PackageNotFoundError: + _torchsde_available = False + +_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None +try: + _invisible_watermark_version = importlib_metadata.version("invisible-watermark") + logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}") +except importlib_metadata.PackageNotFoundError: + _invisible_watermark_available = False + + +def is_torch_available(): + return _torch_available + + +def is_safetensors_available(): + return _safetensors_available + + +def is_tf_available(): + return _tf_available + + +def is_flax_available(): + return _flax_available + + +def is_transformers_available(): + return _transformers_available + + +def is_inflect_available(): + return _inflect_available + + +def is_unidecode_available(): + return _unidecode_available + + +def is_onnx_available(): + return _onnx_available + + +def is_opencv_available(): + return _opencv_available + + +def is_scipy_available(): + return _scipy_available + + +def is_librosa_available(): + return _librosa_available + + +def is_xformers_available(): + return _xformers_available + + +def is_accelerate_available(): + return _accelerate_available + + +def is_k_diffusion_available(): + return _k_diffusion_available + + +def is_note_seq_available(): + return _note_seq_available + + +def is_wandb_available(): + return _wandb_available + + +def is_omegaconf_available(): + return _omegaconf_available + + +def is_tensorboard_available(): + return _tensorboard_available + + +def is_compel_available(): + return _compel_available + + +def is_ftfy_available(): + return _ftfy_available + + +def is_bs4_available(): + return _bs4_available + + +def is_torchsde_available(): + return _torchsde_available + + +def is_invisible_watermark_available(): + return _invisible_watermark_available + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +""" + +# docstyle-ignore +INFLECT_IMPORT_ERROR = """ +{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install +inflect` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +ONNX_IMPORT_ERROR = """ +{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip +install onnxruntime` +""" + +# docstyle-ignore +OPENCV_IMPORT_ERROR = """ +{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip +install opencv-python` +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install +scipy` +""" + +# docstyle-ignore +LIBROSA_IMPORT_ERROR = """ +{0} requires the librosa library but it was not found in your environment. Checkout the instructions on the +installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment. +""" + +# docstyle-ignore +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + +# docstyle-ignore +UNIDECODE_IMPORT_ERROR = """ +{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install +Unidecode` +""" + +# docstyle-ignore +K_DIFFUSION_IMPORT_ERROR = """ +{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip +install k-diffusion` +""" + +# docstyle-ignore +NOTE_SEQ_IMPORT_ERROR = """ +{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip +install note-seq` +""" + +# docstyle-ignore +WANDB_IMPORT_ERROR = """ +{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip +install wandb` +""" + +# docstyle-ignore +OMEGACONF_IMPORT_ERROR = """ +{0} requires the omegaconf library but it was not found in your environment. You can install it with pip: `pip +install omegaconf` +""" + +# docstyle-ignore +TENSORBOARD_IMPORT_ERROR = """ +{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip +install tensorboard` +""" + + +# docstyle-ignore +COMPEL_IMPORT_ERROR = """ +{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel` +""" + +# docstyle-ignore +BS4_IMPORT_ERROR = """ +{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: +`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +FTFY_IMPORT_ERROR = """ +{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the +installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TORCHSDE_IMPORT_ERROR = """ +{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` +""" + +# docstyle-ignore +INVISIBLE_WATERMARK_IMPORT_ERROR = """ +{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` +""" + + +BACKENDS_MAPPING = OrderedDict( + [ + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), + ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), + ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), + ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), + ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), + ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), + ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + if name in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + "UnCLIPPipeline", + ] and is_transformers_version("<", "4.25.0"): + raise ImportError( + f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install" + " --upgrade transformers \n```" + ) + + if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version( + "<", "4.26.0" + ): + raise ImportError( + f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install" + " --upgrade transformers \n```" + ) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 +def is_torch_version(operation: str, version: str): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) + + +def is_transformers_version(operation: str, version: str): + """ + Args: + Compares the current Transformers version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _transformers_available: + return False + return compare_versions(parse(_transformers_version), operation, version) + + +def is_accelerate_version(operation: str, version: str): + """ + Args: + Compares the current Accelerate version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _accelerate_available: + return False + return compare_versions(parse(_accelerate_version), operation, version) + + +def is_k_diffusion_version(operation: str, version: str): + """ + Args: + Compares the current k-diffusion version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _k_diffusion_available: + return False + return compare_versions(parse(_k_diffusion_version), operation, version) + + +class OptionalDependencyNotAvailable(BaseException): + """An error indicating that an optional dependency of Diffusers was not found in the environment.""" diff --git a/diffusers/utils/logging.py b/diffusers/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccc57cd69d57e9bd999e35320cb98416f000522 --- /dev/null +++ b/diffusers/utils/logging.py @@ -0,0 +1,339 @@ +# coding=utf-8 +# Copyright 2023 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import ( + CRITICAL, # NOQA + DEBUG, # NOQA + ERROR, # NOQA + FATAL, # NOQA + INFO, # NOQA + NOTSET, # NOQA + WARN, # NOQA + WARNING, # NOQA +) +from typing import Optional + +from tqdm import auto as tqdm_lib + + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level(): + """ + If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " + f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom diffusers module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the 🤗 Diffusers' root logger as an `int`. + + Returns: + `int`: + Logging level integers which can be one of: + + - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `40`: `diffusers.logging.ERROR` + - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `20`: `diffusers.logging.INFO` + - `10`: `diffusers.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 Diffusers' root logger. + + Args: + verbosity (`int`): + Logging level which can be one of: + + - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - `diffusers.logging.ERROR` + - `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - `diffusers.logging.INFO` + - `diffusers.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info(): + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning(): + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug(): + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error(): + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the 🤗 Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the 🤗 Diffusers' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace Diffusers' root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler not in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for 🤗 Diffusers' loggers. + + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/diffusers/utils/model_card_template.md b/diffusers/utils/model_card_template.md new file mode 100644 index 0000000000000000000000000000000000000000..f19c85b0fcf2f7b07e9c3f950a9657b3f2053f21 --- /dev/null +++ b/diffusers/utils/model_card_template.md @@ -0,0 +1,50 @@ +--- +{{ card_data }} +--- + + + +# {{ model_name | default("Diffusion Model") }} + +## Model description + +This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library +on the `{{ dataset_name }}` dataset. + +## Intended uses & limitations + +#### How to use + +```python +# TODO: add an example code snippet for running this diffusion pipeline +``` + +#### Limitations and bias + +[TODO: provide examples of latent issues and potential remediations] + +## Training data + +[TODO: describe the data used to train the model] + +### Training hyperparameters + +The following hyperparameters were used during training: +- learning_rate: {{ learning_rate }} +- train_batch_size: {{ train_batch_size }} +- eval_batch_size: {{ eval_batch_size }} +- gradient_accumulation_steps: {{ gradient_accumulation_steps }} +- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} +- lr_scheduler: {{ lr_scheduler }} +- lr_warmup_steps: {{ lr_warmup_steps }} +- ema_inv_gamma: {{ ema_inv_gamma }} +- ema_inv_gamma: {{ ema_power }} +- ema_inv_gamma: {{ ema_max_decay }} +- mixed_precision: {{ mixed_precision }} + +### Training results + +📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) + + diff --git a/diffusers/utils/outputs.py b/diffusers/utils/outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..37b11561d1e1ee5d5cb40c7630b132e1f451c5b0 --- /dev/null +++ b/diffusers/utils/outputs.py @@ -0,0 +1,108 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generic utilities +""" + +from collections import OrderedDict +from dataclasses import fields +from typing import Any, Tuple + +import numpy as np + +from .import_utils import is_torch_available + + +def is_tensor(x): + """ + Tests if `x` is a `torch.Tensor` or `np.ndarray`. + """ + if is_torch_available(): + import torch + + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + Python dictionary. + + + + You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + first. + + + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) diff --git a/diffusers/utils/pil_utils.py b/diffusers/utils/pil_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15b97c73dcb7f85b22fcae95c641dde0123b5f05 --- /dev/null +++ b/diffusers/utils/pil_utils.py @@ -0,0 +1,48 @@ +import PIL.Image +import PIL.ImageOps +from packaging import version +from PIL import Image + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } + + +def pt_to_pil(images): + """ + Convert a torch image to a PIL image. + """ + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + images = numpy_to_pil(images) + return images + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images diff --git a/diffusers/utils/testing_utils.py b/diffusers/utils/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..64eb3ac925e9240d30766547880c5dea2e0aeb43 --- /dev/null +++ b/diffusers/utils/testing_utils.py @@ -0,0 +1,602 @@ +import inspect +import logging +import multiprocessing +import os +import random +import re +import tempfile +import unittest +import urllib.parse +from distutils.util import strtobool +from io import BytesIO, StringIO +from pathlib import Path +from typing import List, Optional, Union + +import numpy as np +import PIL.Image +import PIL.ImageOps +import requests +from packaging import version + +from .import_utils import ( + BACKENDS_MAPPING, + is_compel_available, + is_flax_available, + is_note_seq_available, + is_onnx_available, + is_opencv_available, + is_torch_available, + is_torch_version, + is_torchsde_available, +) +from .logging import get_logger + + +global_rng = random.Random() + +logger = get_logger(__name__) + +if is_torch_available(): + import torch + + if "DIFFUSERS_TEST_DEVICE" in os.environ: + torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] + + available_backends = ["cuda", "cpu", "mps"] + if torch_device not in available_backends: + raise ValueError( + f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:" + f" {available_backends}" + ) + logger.info(f"torch_device overrode to {torch_device}") + else: + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + is_torch_higher_equal_than_1_12 = version.parse( + version.parse(torch.__version__).base_version + ) >= version.parse("1.12") + + if is_torch_higher_equal_than_1_12: + # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details + mps_backend_registered = hasattr(torch.backends, "mps") + torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device + + +def torch_all_close(a, b, *args, **kwargs): + if not is_torch_available(): + raise ValueError("PyTorch needs to be installed to use this function.") + if not torch.allclose(a, b, *args, **kwargs): + assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}." + return True + + +def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_name="expected_slice"): + test_name = os.environ.get("PYTEST_CURRENT_TEST") + if not torch.is_tensor(tensor): + tensor = torch.from_numpy(tensor) + + tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "") + # format is usually: + # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161]) + output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array") + test_file, test_class, test_fn = test_name.split("::") + test_fn = test_fn.split()[0] + with open(filename, "a") as f: + print(";".join([test_file, test_class, test_fn, output_str]), file=f) + + +def get_tests_dir(append_path=None): + """ + Args: + append_path: optional path to append to the tests dir path + Return: + The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is + joined after the `tests` dir the former is provided. + """ + # this function caller's __file__ + caller__file__ = inspect.stack()[1][1] + tests_dir = os.path.abspath(os.path.dirname(caller__file__)) + + while not tests_dir.endswith("tests"): + tests_dir = os.path.dirname(tests_dir) + + if append_path: + return os.path.join(tests_dir, append_path) + else: + return tests_dir + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) +_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False) + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def nightly(test_case): + """ + Decorator marking a test that runs nightly in the diffusers CI. + + Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case) + + +def require_torch(test_case): + """ + Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. + """ + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) + + +def require_torch_2(test_case): + """ + Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. + """ + return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")( + test_case + ) + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( + test_case + ) + + +def skip_mps(test_case): + """Decorator marking a test to skip if torch_device is 'mps'""" + return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case) + + +def require_flax(test_case): + """ + Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed + """ + return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) + + +def require_compel(test_case): + """ + Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when + the library is not installed. + """ + return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case) + + +def require_onnxruntime(test_case): + """ + Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed. + """ + return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case) + + +def require_note_seq(test_case): + """ + Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed. + """ + return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case) + + +def require_torchsde(test_case): + """ + Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed. + """ + return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case) + + +def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray: + if isinstance(arry, str): + # local_path = "/home/patrick_huggingface_co/" + if local_path is not None: + # local_path can be passed to correct images of tests + return os.path.join(local_path, "/".join([arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]])) + elif arry.startswith("http://") or arry.startswith("https://"): + response = requests.get(arry) + response.raise_for_status() + arry = np.load(BytesIO(response.content)) + elif os.path.isfile(arry): + arry = np.load(arry) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path" + ) + elif isinstance(arry, np.ndarray): + pass + else: + raise ValueError( + "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a" + " ndarray." + ) + + return arry + + +def load_pt(url: str): + response = requests.get(url) + response.raise_for_status() + arry = torch.load(BytesIO(response.content)) + return arry + + +def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + Returns: + `PIL.Image.Image`: + A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." + ) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +def preprocess_image(image: PIL.Image, batch_size: int): + w, h = image.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str: + if output_gif_path is None: + output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name + + image[0].save( + output_gif_path, + save_all=True, + append_images=image[1:], + optimize=False, + duration=100, + loop=0, + ) + return output_gif_path + + +def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, c = video_frames[0].shape + video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h)) + for i in range(len(video_frames)): + img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + return output_video_path + + +def load_hf_numpy(path) -> np.ndarray: + if not path.startswith("http://") or path.startswith("https://"): + path = os.path.join( + "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path) + ) + + return load_numpy(path) + + +# --- pytest conf functions --- # + +# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once +pytest_opt_registered = {} + + +def pytest_addoption_shared(parser): + """ + This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there. + + It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest` + option. + + """ + option = "--make-reports" + if option not in pytest_opt_registered: + parser.addoption( + option, + action="store", + default=False, + help="generate report files. The value of this option is used as a prefix to report names", + ) + pytest_opt_registered[option] = 1 + + +def pytest_terminal_summary_main(tr, id): + """ + Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current + directory. The report files are prefixed with the test suite name. + + This function emulates --duration and -rA pytest arguments. + + This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined + there. + + Args: + - tr: `terminalreporter` passed from `conftest.py` + - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is + needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. + + NB: this functions taps into a private _pytest API and while unlikely, it could break should + pytest do internal changes - also it calls default internal methods of terminalreporter which + can be hijacked by various `pytest-` plugins and interfere. + + """ + from _pytest.config import create_terminal_writer + + if not len(id): + id = "tests" + + config = tr.config + orig_writer = config.get_terminal_writer() + orig_tbstyle = config.option.tbstyle + orig_reportchars = tr.reportchars + + dir = "reports" + Path(dir).mkdir(parents=True, exist_ok=True) + report_files = { + k: f"{dir}/{id}_{k}.txt" + for k in [ + "durations", + "errors", + "failures_long", + "failures_short", + "failures_line", + "passes", + "stats", + "summary_short", + "warnings", + ] + } + + # custom durations report + # note: there is no need to call pytest --durations=XX to get this separate report + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 + dlist = [] + for replist in tr.stats.values(): + for rep in replist: + if hasattr(rep, "duration"): + dlist.append(rep) + if dlist: + dlist.sort(key=lambda x: x.duration, reverse=True) + with open(report_files["durations"], "w") as f: + durations_min = 0.05 # sec + f.write("slowest durations\n") + for i, rep in enumerate(dlist): + if rep.duration < durations_min: + f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") + break + f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") + + def summary_failures_short(tr): + # expecting that the reports were --tb=long (default) so we chop them off here to the last frame + reports = tr.getreports("failed") + if not reports: + return + tr.write_sep("=", "FAILURES SHORT STACK") + for rep in reports: + msg = tr._getfailureheadline(rep) + tr.write_sep("_", msg, red=True, bold=True) + # chop off the optional leading extra frames, leaving only the last one + longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) + tr._tw.line(longrepr) + # note: not printing out any rep.sections to keep the report short + + # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 + # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. + # pytest-instafail does that) + + # report failures with line/short/long styles + config.option.tbstyle = "auto" # full tb + with open(report_files["failures_long"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + # config.option.tbstyle = "short" # short tb + with open(report_files["failures_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + summary_failures_short(tr) + + config.option.tbstyle = "line" # one line per error + with open(report_files["failures_line"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + with open(report_files["errors"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_errors() + + with open(report_files["warnings"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_warnings() # normal warnings + tr.summary_warnings() # final warnings + + tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) + with open(report_files["passes"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_passes() + + with open(report_files["summary_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.short_test_summary() + + with open(report_files["stats"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_stats() + + # restore: + tr._tw = orig_writer + tr.reportchars = orig_reportchars + config.option.tbstyle = orig_tbstyle + + +# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787 +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): + """ + To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. + + Args: + test_case (`unittest.TestCase`): + The test that will run `target_func`. + target_func (`Callable`): + The function implementing the actual testing logic. + inputs (`dict`, *optional*, defaults to `None`): + The inputs that will be passed to `target_func` through an (input) queue. + timeout (`int`, *optional*, defaults to `None`): + The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env. + variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`. + """ + if timeout is None: + timeout = int(os.environ.get("PYTEST_TIMEOUT", 600)) + + start_methohd = "spawn" + ctx = multiprocessing.get_context(start_methohd) + + input_queue = ctx.Queue(1) + output_queue = ctx.JoinableQueue(1) + + # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. + input_queue.put(inputs, timeout=timeout) + + process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) + process.start() + # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents + # the test to exit properly. + try: + results = output_queue.get(timeout=timeout) + output_queue.task_done() + except Exception as e: + process.terminate() + test_case.fail(e) + process.join(timeout=timeout) + + if results["error"] is not None: + test_case.fail(f'{results["error"]}') + + +class CaptureLogger: + """ + Args: + Context manager to capture `logging` streams + logger: 'logging` logger object + Returns: + The captured output is available via `self.out` + Example: + ```python + >>> from diffusers import logging + >>> from diffusers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py") + >>> with CaptureLogger(logger) as cl: + ... logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" + + +def enable_full_determinism(): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cuda.matmul.allow_tf32 = False + + +def disable_full_determinism(): + os.environ["CUDA_LAUNCH_BLOCKING"] = "0" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" + torch.use_deterministic_algorithms(False) diff --git a/diffusers/utils/torch_utils.py b/diffusers/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f64bce25e78d5212696f4b06b767d338599670a --- /dev/null +++ b/diffusers/utils/torch_utils.py @@ -0,0 +1,84 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PyTorch utilities: Utilities related to PyTorch +""" +from typing import List, Optional, Tuple, Union + +from . import logging +from .import_utils import is_torch_available, is_torch_version + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +try: + from torch._dynamo import allow_in_graph as maybe_allow_in_graph +except (ImportError, ModuleNotFoundError): + + def maybe_allow_in_graph(cls): + return cls + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + logger.info( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents + + +def is_compiled_module(module): + """Check whether the module was compiled with torch.compile()""" + if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): + return False + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) diff --git a/docs/annotator.md b/docs/annotator.md new file mode 100644 index 0000000000000000000000000000000000000000..5111678ff11407da19a920e0c500ee0480bbb155 --- /dev/null +++ b/docs/annotator.md @@ -0,0 +1,49 @@ +# Automatic Annotations + +We provide gradio examples to obtain annotations that are aligned to our pretrained production-ready models. + +Just run + + python gradio_annotator.py + +Since everyone has different habit to organize their datasets, we do not hard code any scripts for batch processing. But "gradio_annotator.py" is written in a super readable way, and modifying it to annotate your images should be easy. + +In the gradio UI of "gradio_annotator.py" we have the following interfaces: + +### Canny Edge + +Be careful about "black edge and white background" or "white edge and black background". + +![p](../github_page/a1.png) + +### HED Edge + +Be careful about "black edge and white background" or "white edge and black background". + +![p](../github_page/a2.png) + +### MLSD Edge + +Be careful about "black edge and white background" or "white edge and black background". + +![p](../github_page/a3.png) + +### MIDAS Depth and Normal + +Be careful about RGB or BGR in normal maps. + +![p](../github_page/a4.png) + +### Openpose + +Be careful about RGB or BGR in pose maps. + +For our production-ready model, the hand pose option is turned off. + +![p](../github_page/a5.png) + +### Uniformer Segmentation + +Be careful about RGB or BGR in segmentation maps. + +![p](../github_page/a6.png) diff --git a/docs/faq.md b/docs/faq.md new file mode 100644 index 0000000000000000000000000000000000000000..07afd7aeacb51cac4c8bac3b601fe23a2842c4d3 --- /dev/null +++ b/docs/faq.md @@ -0,0 +1,21 @@ +# FAQs + +**Q:** If the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why "zero convolution" works? + +**A:** This is wrong. Let us consider a very simple + +$$y=wx+b$$ + +and we have + +$$\partial y/\partial w=x, \partial y/\partial x=w, \partial y/\partial b=1$$ + +and if $w=0$ and $x \neq 0$, then + +$$\partial y/\partial w \neq 0, \partial y/\partial x=0, \partial y/\partial b\neq 0$$ + +which means as long as $x \neq 0$, one gradient descent iteration will make $w$ non-zero. Then + +$$\partial y/\partial x\neq 0$$ + +so that the zero convolutions will progressively become a common conv layer with non-zero weights. diff --git a/docs/low_vram.md b/docs/low_vram.md new file mode 100644 index 0000000000000000000000000000000000000000..784964c78d5074ed9d318456d7c35f30a81f04ed --- /dev/null +++ b/docs/low_vram.md @@ -0,0 +1,15 @@ +# Enable Low VRAM Mode + +If you are using 8GB GPU card (or if you want larger batch size), please open "config.py", and then set + +```python +save_memory = True +``` + +This feature is still being tested - not all graphics cards are guaranteed to succeed. + +But it should be neat as I can diffuse at a batch size of 12 now. + +(prompt "man") + +![p](../github_page/ram12.jpg) diff --git a/docs/train.md b/docs/train.md new file mode 100644 index 0000000000000000000000000000000000000000..fa773925e28c4100df4bf74f0536d432554db806 --- /dev/null +++ b/docs/train.md @@ -0,0 +1,276 @@ +# Train a ControlNet to Control SD + +You are here because you want to control SD in your own way, maybe you have an idea for your perfect research project, and you will annotate some data or have already annotated your own dataset automatically or manually. Herein, the control can be anything that can be converted to images, such as edges, keypoints, segments, etc. + +Before moving on to your own dataset, we highly recommend to first try the toy dataset, Fill50K, as a sanity check. This will help you get a "feeling" for the training. You will know how long it will take for the model to converge and whether your device will be able to complete the training in an acceptable amount of time. And what it "feels" like when the model converges. + +We hope that after you read this page, you will find that training a ControlNet is as easy as (or easier than) training a pix2pix. + +## Step 0 - Design your control + +Let us take a look at a very simple task to control SD to fill color in circles. + +![p](../github_page/t1.png) + +This is simple: we want to control SD to fill a circle with colors, and the prompt contains some description of our target. + +Stable diffusion is trained on billions of images, and it already knows what is "cyan", what is "circle", what is "pink", and what is "background". + +But it does not know the meaning of that "Control Image (Source Image)". Our target is to let it know. + +## Step 1 - Get a dataset + +Just download the Fill50K dataset from [our huggingface page](https://huggingface.co/lllyasviel/ControlNet) (training/fill50k.zip, the file is only 200M!). Make sure that the data is decompressed as + + ControlNet/training/fill50k/prompt.json + ControlNet/training/fill50k/source/X.png + ControlNet/training/fill50k/target/X.png + +In the folder "fill50k/source", you will have 50k images of circle lines. + +![p](../github_page/t2.png) + +In the folder "fill50k/target", you will have 50k images of filled circles. + +![p](../github_page/t3.png) + +In the "fill50k/prompt.json", you will have their filenames and prompts. Each prompt is like "a balabala color circle in some other color background." + +![p](../github_page/t4.png) + +## Step 2 - Load the dataset + +Then you need to write a simple script to read this dataset for pytorch. (In fact we have written it for you in "tutorial_dataset.py".) + +```python +import json +import cv2 +import numpy as np + +from torch.utils.data import Dataset + + +class MyDataset(Dataset): + def __init__(self): + self.data = [] + with open('./training/fill50k/prompt.json', 'rt') as f: + for line in f: + self.data.append(json.loads(line)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + + source_filename = item['source'] + target_filename = item['target'] + prompt = item['prompt'] + + source = cv2.imread('./training/fill50k/' + source_filename) + target = cv2.imread('./training/fill50k/' + target_filename) + + # Do not forget that OpenCV read images in BGR order. + source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB) + target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB) + + # Normalize source images to [0, 1]. + source = source.astype(np.float32) / 255.0 + + # Normalize target images to [-1, 1]. + target = (target.astype(np.float32) / 127.5) - 1.0 + + return dict(jpg=target, txt=prompt, hint=source) + +``` + +This will make your dataset into an array-like object in python. You can test this dataset simply by accessing the array, like this + +```python +from tutorial_dataset import MyDataset + +dataset = MyDataset() +print(len(dataset)) + +item = dataset[1234] +jpg = item['jpg'] +txt = item['txt'] +hint = item['hint'] +print(txt) +print(jpg.shape) +print(hint.shape) + +``` + +The outputs of this simple test on my machine are + + 50000 + burly wood circle with orange background + (512, 512, 3) + (512, 512, 3) + +And this code is in "tutorial_dataset_test.py". + +In this way, the dataset is an array-like object with 50000 items. Each item is a dict with three entry "jpg", "txt", and "hint". The "jpg" is the target image, the "hint" is the control image, and the "txt" is the prompt. + +Do not ask us why we use these three names - this is related to the dark history of a library called LDM. + +## Step 3 - What SD model do you want to control? + +Then you need to decide which Stable Diffusion Model you want to control. In this example, we will just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). + +(Or ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main) if you are using SD2.) + +Then you need to attach a control net to the SD model. The architecture is + +![img](../github_page/sd.png) + +Note that all weights inside the ControlNet are also copied from SD so that no layer is trained from scratch, and you are still finetuning the entire model. + +We provide a simple script for you to achieve this easily. If your SD filename is "./models/v1-5-pruned.ckpt" and you want the script to save the processed model (SD+ControlNet) at location "./models/control_sd15_ini.ckpt", you can just run: + + python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt + +Or if you are using SD2: + + python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt + +You may also use other filenames as long as the command is "python tool_add_control.py input_path output_path". + +This is the correct output from my machine: + +![img](../github_page/t5.png) + +## Step 4 - Train! + +Happy! We finally come to the most exciting part: training! + +The training code in "tutorial_train.py" is actually surprisingly simple: + +```python +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from tutorial_dataset import MyDataset +from cldm.logger import ImageLogger +from cldm.model import create_model, load_state_dict + + +# Configs +resume_path = './models/control_sd15_ini.ckpt' +batch_size = 4 +logger_freq = 300 +learning_rate = 1e-5 +sd_locked = True +only_mid_control = False + + +# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs. +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict(resume_path, location='cpu')) +model.learning_rate = learning_rate +model.sd_locked = sd_locked +model.only_mid_control = only_mid_control + + +# Misc +dataset = MyDataset() +dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True) +logger = ImageLogger(batch_frequency=logger_freq) +trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger]) + + +# Train! +trainer.fit(model, dataloader) + +``` +(or "tutorial_train_sd21.py" if you are using SD2) + +Thanks to our organized dataset pytorch object and the power of pytorch_lightning, the entire code is just super short. + +Now, you may take a look at [Pytorch Lightning Official DOC](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#trainer) to find out how to enable many useful features like gradient accumulation, multiple GPU training, accelerated dataset loading, flexible checkpoint saving, etc. All these only need about one line of code. Great! + +Note that if you find OOM, perhaps you need to enable [Low VRAM mode](low_vram.md), and perhaps you also need to use smaller batch size and gradient accumulation. Or you may also want to use some “advanced” tricks like sliced attention or xformers. For example: + +```python +# Configs +batch_size = 1 + +# Misc +trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger], accumulate_grad_batches=4) # But this will be 4x slower +``` + +Note that training with 8 GB laptop GPU is challenging. We will need some GPU memory optimization at least as good as automatic1111’s UI. This may require expert modifications to the code. + +### Screenshots + +The training is fast. After 4000 steps (batch size 4, learning rate 1e-5, about 50 minutes on PCIE 40G), the results on my machine (in an output folder "image_log") is + +Control: + +![img](../github_page/t/ip.png) + +Prompt: + +![img](../github_page/t/t.png) + +Prediction: + +![img](../github_page/t/op.png) + +Ground Truth: + +![img](../github_page/t/gt.png) + +Note that the SD's capability is preserved. Even training on this super aligned dataset, it still draws some random textures and those snow decorations. (Besides, note that the ground truth looks a bit modified because it is converted from SD's latent image.) + +Larger batch size and longer training will further improve this. Adequate training will make the filling perfect. + +Of course, training SD to fill circles is meaningless, but this is a successful beginning of your story. + +Let us work together to control large models more and more. + +## Other options + +Beyond standard things, we also provide two important parameters "sd_locked" and "only_mid_control" that you need to know. + +### only_mid_control + +By default, only_mid_control is False. When it is True, you will train the below architecture. + +![img](../github_page/t6.png) + +This can be helpful when your computation power is limited and want to speed up the training, or when you want to facilitate the "global" context learning. Note that sometimes you may pause training, set it to True, resume training, and pause again, and set it again, and resume again. + +If your computation device is good, perhaps you do not need this. But I also know some artists are willing to train a model on their laptop for a month - in that case, perhaps this option can be useful. + +### sd_locked + +By default, sd_locked is True. When it is False, you will train the below architecture. + +![img](../github_page/t7.png) + +This will unlock some layers in SD and you will train them as a whole. + +This option is DANGEROUS! If your dataset is not good enough, this may downgrade the capability of your SD model. + +However, this option is also very useful when you are training on images with some specific style, or when you are training with special datasets (like medical dataset with X-ray images or geographic datasets with lots of Google Maps). You can understand this as simultaneously training the ControlNet and something like a DreamBooth. + +Also, if your dataset is large, you may want to end the training with a few thousands of steps with those layer unlocked. This usually improve the "problem-specific" solutions a little. You may try it yourself to feel the difference. + +Also, if you unlock some original layers, you may want a lower learning rate, like 2e-6. + +## More Consideration: Sudden Converge Phenomenon and Gradient Accumulation + +![img](../github_page/ex1.jpg) + +Because we use zero convolutions, the SD should always be able to predict meaningful images. (If it cannot, the training has already failed.) + +You will always find that at some iterations, the model "suddenly" be able to fit some training conditions. This means that you will get a basically usable model at about 3k to 7k steps (future training will improve it, but that model after the first "sudden converge" should be basically functional). + +Note that 3k to 7k steps is not very large, and you should consider larger batch size rather than more training steps. If you can observe the "sudden converge" at 3k step using batch size 4, then, rather than train it with 300k further steps, a better idea is to use 100× gradient accumulation to re-train that 3k steps with 100× batch size. Note that perhaps we should not do this *too* extremely (perhaps 100x accumulation is too extreme), but you should consider that, since "sudden converge" will *always* happen at that certain point, getting a better converge is more important. + +Because that "sudden converge" always happens, lets say "sudden converge" will happen at 3k step and our money can optimize 90k step, then we have two options: (1) train 3k steps, sudden converge, then train 87k steps. (2) 30x gradient accumulation, train 3k steps (90k real computation steps), then sudden converge. + +In my experiments, (2) is usually better than (1). However, in real cases, perhaps you may need to balance the steps before and after the "sudden converge" on your own to find a balance. The training after "sudden converge" is also important. + +But usually, if your logic batch size is already bigger than 256, then further extending the batch size is not very meaningful. In that case, perhaps a better idea is to train more steps. I tried some "common" logic batch size at 64 or 96 or 128 (by gradient accumulation), it seems that many complicated conditions can be solved very well already. diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91463f0fb17145d7649e298819811ac9a21d6b93 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,35 @@ +name: control +channels: + - pytorch + - defaults +dependencies: + - python=3.8.5 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.12.1 + - torchvision=0.13.1 + - numpy=1.23.1 + - pip: + - gradio==3.16.2 + - albumentations==1.3.0 + - opencv-contrib-python==4.3.0.36 + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.5.0 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit==1.12.1 + - einops==0.3.0 + - transformers==4.19.2 + - webdataset==0.2.5 + - kornia==0.6 + - open_clip_torch==2.0.2 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.6.0 + - timm==0.6.12 + - addict==2.4.0 + - yapf==0.32.0 + - prettytable==3.6.0 + - safetensors==0.2.7 + - basicsr==1.4.2 diff --git a/font/DejaVuSans.ttf b/font/DejaVuSans.ttf new file mode 100644 index 0000000000000000000000000000000000000000..e5f7eecce43be41ff0703ed99e1553029b849f14 Binary files /dev/null and b/font/DejaVuSans.ttf differ diff --git a/github_page/a1.png b/github_page/a1.png new file mode 100644 index 0000000000000000000000000000000000000000..cb3312c9db173924d4bf99d13fc8d06086f6cb79 Binary files /dev/null and b/github_page/a1.png differ diff --git a/github_page/a2.png b/github_page/a2.png new file mode 100644 index 0000000000000000000000000000000000000000..efa1ee82e5c76ec6b5fe9f19fdddbe48575c2924 Binary files /dev/null and b/github_page/a2.png differ diff --git a/github_page/a3.png b/github_page/a3.png new file mode 100644 index 0000000000000000000000000000000000000000..521a093b08d587b8dec4f7f3d7fbd2db32ee2c43 Binary files /dev/null and b/github_page/a3.png differ diff --git a/github_page/a4.png b/github_page/a4.png new file mode 100644 index 0000000000000000000000000000000000000000..d9fc6ee817d79370ddce9d7e57a7f9a63010df73 Binary files /dev/null and b/github_page/a4.png differ diff --git a/github_page/a5.png b/github_page/a5.png new file mode 100644 index 0000000000000000000000000000000000000000..1f4fdddb30a3dedb1b2f194ca8009c2704313cd2 Binary files /dev/null and b/github_page/a5.png differ diff --git a/github_page/a6.png b/github_page/a6.png new file mode 100644 index 0000000000000000000000000000000000000000..75555245336b3ba6de7d6abe554008e9da199b40 Binary files /dev/null and b/github_page/a6.png differ diff --git a/github_page/control.pdf b/github_page/control.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a4eb41d88516a815097da3a8e2e0b99d0dca6794 --- /dev/null +++ b/github_page/control.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed468ba5e8e20f250a6949643398c18987dd43e48ebb5f1d6b81a134ec10f90e +size 21869744 diff --git a/github_page/ex1.jpg b/github_page/ex1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..89df766784b0e62481ca90d713db1574f8c7990f Binary files /dev/null and b/github_page/ex1.jpg differ diff --git a/github_page/he.png b/github_page/he.png new file mode 100644 index 0000000000000000000000000000000000000000..7af7f18ecbedd3449f2dd4f92c825bf2a2c9c8f4 Binary files /dev/null and b/github_page/he.png differ diff --git a/github_page/multi.png b/github_page/multi.png new file mode 100644 index 0000000000000000000000000000000000000000..b7e8c6103c82d11137ec646588fa899dc66c2ea4 Binary files /dev/null and b/github_page/multi.png differ diff --git a/github_page/multi2.png b/github_page/multi2.png new file mode 100644 index 0000000000000000000000000000000000000000..10d6b17655a828ec1165feca89193d9cb5566057 Binary files /dev/null and b/github_page/multi2.png differ diff --git a/github_page/p1.png b/github_page/p1.png new file mode 100644 index 0000000000000000000000000000000000000000..c9f8cc8e27977f146ea44b6d2fe1757acaa4e091 Binary files /dev/null and b/github_page/p1.png differ diff --git a/github_page/p10.png b/github_page/p10.png new file mode 100644 index 0000000000000000000000000000000000000000..ca8d5f6175bc6bcb7f251b056b08858813470dd1 --- /dev/null +++ b/github_page/p10.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea30b36f988db5e787e3d061a1745bc94a2991bb76e38325459e79326ae449be +size 1558223 diff --git a/github_page/p11.png b/github_page/p11.png new file mode 100644 index 0000000000000000000000000000000000000000..eb0ecfc2f81a4a65ce3422c5c968e2222bed8824 --- /dev/null +++ b/github_page/p11.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bba5160851111c6758fd0e34fa543369deb443ed6530b93b84d547384a19aef +size 1025272 diff --git a/github_page/p12.png b/github_page/p12.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd4bc0005eff38ec6d90e9b8dc36eba975bffba --- /dev/null +++ b/github_page/p12.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a609ddf0e88c4efbb5ba324935a521f57de94a6ebbc2b24353a110a2cdfbef3f +size 1037234 diff --git a/github_page/p13.png b/github_page/p13.png new file mode 100644 index 0000000000000000000000000000000000000000..8ac6bc4400bf71c7d709bfcda0e049f8a4f13115 --- /dev/null +++ b/github_page/p13.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5381027bb60cecb3d20a5a33c614c2bbf313090384e0f3b20117ac58d926062 +size 1140361 diff --git a/github_page/p14.png b/github_page/p14.png new file mode 100644 index 0000000000000000000000000000000000000000..88c1834789cfe1eb67acb395a227e330ff946217 --- /dev/null +++ b/github_page/p14.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8db031dce5ce4be02cc4d3f82fedd5fd4ac1a595c7fc8c9ba36978d475315aad +size 1398887 diff --git a/github_page/p15.png b/github_page/p15.png new file mode 100644 index 0000000000000000000000000000000000000000..1e4b8e4152a326f2bf6cff4a7f95611007a4c5a0 --- /dev/null +++ b/github_page/p15.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:748a7879c839e329ba0921aa4015753c5879811cee2976dcb6636c42fee674db +size 1366530 diff --git a/github_page/p16b.png b/github_page/p16b.png new file mode 100644 index 0000000000000000000000000000000000000000..6556de32514b3c5c57c238ee59577071450e210a --- /dev/null +++ b/github_page/p16b.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d07b59462d44e9d442a208bb0d457cf7e5ba41059fd1ebd5ca30d23e982e430 +size 1399940 diff --git a/github_page/p17.png b/github_page/p17.png new file mode 100644 index 0000000000000000000000000000000000000000..de60192873a6e7c1963c7350121b56260390c405 --- /dev/null +++ b/github_page/p17.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b39ffbc341b431121ac82ae8cd235afbcd7c5e30f57eef58f00d4bb8dcbda6af +size 1446761 diff --git a/github_page/p18.png b/github_page/p18.png new file mode 100644 index 0000000000000000000000000000000000000000..3278afee8800111bec6c246515de76086b8864cd --- /dev/null +++ b/github_page/p18.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:063bea46c7d9eb7757d5b5674e0cb84f054c96472a38553d4358ad8ab86358b8 +size 1057593 diff --git a/github_page/p19.png b/github_page/p19.png new file mode 100644 index 0000000000000000000000000000000000000000..3648956ee0370a748e1f2adfdc62b7182a8a568a --- /dev/null +++ b/github_page/p19.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:502a2242ad1703e47a5be9398fbd301d3c287f21b85d92f5be314477673b62ac +size 1113879 diff --git a/github_page/p2.png b/github_page/p2.png new file mode 100644 index 0000000000000000000000000000000000000000..de2c6353ff834c82bc63ebbfacc8954d9e8ebec7 --- /dev/null +++ b/github_page/p2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:150a4298489a26844f44b51e365b07c454102f83f571133326c1121b1a939f40 +size 1318561 diff --git a/github_page/p20.png b/github_page/p20.png new file mode 100644 index 0000000000000000000000000000000000000000..13b9c645331b0ee8d0f9a69122f26074caa6246e --- /dev/null +++ b/github_page/p20.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b0511aaba8b1dfcfcc87a684c0c1f30339d88519aba475be356d0dc30090deb +size 1341061 diff --git a/github_page/p21.png b/github_page/p21.png new file mode 100644 index 0000000000000000000000000000000000000000..14bdbd5ff8e4c159ebb54b787306f3147948346d --- /dev/null +++ b/github_page/p21.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:709cf6a47616a292c72a03fe046d6158a3004cd8f972223f24a867328908a6a6 +size 1810899 diff --git a/github_page/p3.png b/github_page/p3.png new file mode 100644 index 0000000000000000000000000000000000000000..927efb593523cb1554f1c736ae9e539b5cd6b130 --- /dev/null +++ b/github_page/p3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5de3f7e26f159237f15439c7dc4b13b85f8f2214b4ec5d3577018bc4272e27b +size 1203738 diff --git a/github_page/p4.png b/github_page/p4.png new file mode 100644 index 0000000000000000000000000000000000000000..6d5e845b046771398f0d2fa3d8d475f5df92fcfc --- /dev/null +++ b/github_page/p4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a038d132732106731ce899e6d924a34bbcf19d89318b768548683478c1a42590 +size 1159029 diff --git a/github_page/p5.png b/github_page/p5.png new file mode 100644 index 0000000000000000000000000000000000000000..aaa151beb0abfbd87b46b0d663ada50ae4af1f81 --- /dev/null +++ b/github_page/p5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b0584509963f00842a98a42b4edcd9c84c0310753e6e8861979f87d7f0f27fc2 +size 1268753 diff --git a/github_page/p6.png b/github_page/p6.png new file mode 100644 index 0000000000000000000000000000000000000000..77eb22c8e7762ede1786196bcb7634fc522ed128 --- /dev/null +++ b/github_page/p6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4258437186ecc95d4ec63728afdc43af902b1afaf3719e5cccba2fdfc755c0c4 +size 1454815 diff --git a/github_page/p7.png b/github_page/p7.png new file mode 100644 index 0000000000000000000000000000000000000000..3241d45777fe8d6416cb97e051e74a203d34d0f2 --- /dev/null +++ b/github_page/p7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:557e0f50015a0a681a29eb7d9747a1a251c87196eb4bdfa0bfb4cbede97d3123 +size 1671023 diff --git a/github_page/p8.png b/github_page/p8.png new file mode 100644 index 0000000000000000000000000000000000000000..ab72d0dde930fbd3b173f3efc3e4a64532e22693 --- /dev/null +++ b/github_page/p8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d53b77b5dc70ddc5adc0a3ffa354cefdfd422637bf6c6c48e99f7f2c2598851 +size 1555201 diff --git a/github_page/p9.png b/github_page/p9.png new file mode 100644 index 0000000000000000000000000000000000000000..03e1233f888a74f9edbd440136174cad2d2909d2 --- /dev/null +++ b/github_page/p9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcf3b888b15e2d469b204659e9bcb9d8c12b7d5ca9be4411fb3b25c2e958c074 +size 1428057 diff --git a/github_page/ram12.jpg b/github_page/ram12.jpg new file mode 100644 index 0000000000000000000000000000000000000000..208026eda25f97d09dcf80a020a16e5de1e77230 Binary files /dev/null and b/github_page/ram12.jpg differ diff --git a/github_page/sd.png b/github_page/sd.png new file mode 100644 index 0000000000000000000000000000000000000000..2eb3a8567e5278cae0d7a29e65e965d09cba4797 Binary files /dev/null and b/github_page/sd.png differ diff --git a/github_page/t/gt.png b/github_page/t/gt.png new file mode 100644 index 0000000000000000000000000000000000000000..df81bcec5be141763af561dded39e766a7de137d Binary files /dev/null and b/github_page/t/gt.png differ diff --git a/github_page/t/ip.png b/github_page/t/ip.png new file mode 100644 index 0000000000000000000000000000000000000000..cc980a22d1e14903d9bd556b9a6c21fbfecbcc63 Binary files /dev/null and b/github_page/t/ip.png differ diff --git a/github_page/t/op.png b/github_page/t/op.png new file mode 100644 index 0000000000000000000000000000000000000000..0e51e87cec204f1697e755c8c644f53c7f3dd8a4 --- /dev/null +++ b/github_page/t/op.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9833659efcb7353b150896028bf70b6123ca132b658ee65d74d2b40e233aa6b +size 1644715 diff --git a/github_page/t/t.png b/github_page/t/t.png new file mode 100644 index 0000000000000000000000000000000000000000..404fad5d89abb07da21686e27b5cd2603381fdfc Binary files /dev/null and b/github_page/t/t.png differ diff --git a/github_page/t1.png b/github_page/t1.png new file mode 100644 index 0000000000000000000000000000000000000000..4e6387a279cc48dce605c3732a08bb30360b6116 Binary files /dev/null and b/github_page/t1.png differ diff --git a/github_page/t2.png b/github_page/t2.png new file mode 100644 index 0000000000000000000000000000000000000000..a914399475f0d5ba585e0b2c4e565a4b6a4fe42a Binary files /dev/null and b/github_page/t2.png differ diff --git a/github_page/t3.png b/github_page/t3.png new file mode 100644 index 0000000000000000000000000000000000000000..5d27510f96216187b6c08277dacf9cc5b7ab6094 Binary files /dev/null and b/github_page/t3.png differ diff --git a/github_page/t4.png b/github_page/t4.png new file mode 100644 index 0000000000000000000000000000000000000000..accc7f76d921a64360c98c59f575332dab2202ea Binary files /dev/null and b/github_page/t4.png differ diff --git a/github_page/t5.png b/github_page/t5.png new file mode 100644 index 0000000000000000000000000000000000000000..714af4b811a7934ca6eda46788aff660a7569bd5 Binary files /dev/null and b/github_page/t5.png differ diff --git a/github_page/t6.png b/github_page/t6.png new file mode 100644 index 0000000000000000000000000000000000000000..db8c1d4ce649e130b4fb83c6412efa15ecc71b42 Binary files /dev/null and b/github_page/t6.png differ diff --git a/github_page/t7.png b/github_page/t7.png new file mode 100644 index 0000000000000000000000000000000000000000..43a3dac046c25fd18ee1ce89389a16bae4633b3d Binary files /dev/null and b/github_page/t7.png differ diff --git a/github_page/uc1.png b/github_page/uc1.png new file mode 100644 index 0000000000000000000000000000000000000000..e3b874e0497103be64127586aeb89a9f4b261b6d Binary files /dev/null and b/github_page/uc1.png differ diff --git a/github_page/uc2a.png b/github_page/uc2a.png new file mode 100644 index 0000000000000000000000000000000000000000..3a6861226a641b1952445587bf0cec4441fa8167 --- /dev/null +++ b/github_page/uc2a.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96d60afb5cb0975f9b21e7b88aa506c92d2a0fcfcac82aeb58f94d3624355c64 +size 1100385 diff --git a/github_page/uc2b.png b/github_page/uc2b.png new file mode 100644 index 0000000000000000000000000000000000000000..715008a3e93f093d655e4891bf1ec7a4d24d586e --- /dev/null +++ b/github_page/uc2b.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73a23761f4843978657c65d62380fec01641f590d69f36dcb73267e9a28f4544 +size 1555138 diff --git a/github_page/uc3.png b/github_page/uc3.png new file mode 100644 index 0000000000000000000000000000000000000000..c177def0fc8d375d21d7e355a2fe378009bb53e5 --- /dev/null +++ b/github_page/uc3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56d2ce31584612d38a6bb94af473f261fc29a6013fa9da60f77fa52857f34e53 +size 1333667 diff --git a/github_page/uc4.png b/github_page/uc4.png new file mode 100644 index 0000000000000000000000000000000000000000..b662265088fcbc69affd9ca31dcec5b88499e52e --- /dev/null +++ b/github_page/uc4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a6b6f035ddfcef6b96c3e869113a436f11affa8d3cfef2ff7c7050a856b0fe4 +size 1294259 diff --git a/github_page/uc6.png b/github_page/uc6.png new file mode 100644 index 0000000000000000000000000000000000000000..e3009ffc3309c21d987b392f91f1eda0038349a2 --- /dev/null +++ b/github_page/uc6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:905fd75afb29e6148e3f962e55620483a17120d6999d0b2bcae092511465157e +size 1830853 diff --git a/github_page/uci1.png b/github_page/uci1.png new file mode 100644 index 0000000000000000000000000000000000000000..5e34aad358862bbabc775f624d152318dc10e5ee --- /dev/null +++ b/github_page/uci1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb6335f749ac28fd1f6639c7bb2fa03b21eb6a0794ac6b3cc1df6c3b592c09e7 +size 1091153 diff --git a/github_page/uci2.png b/github_page/uci2.png new file mode 100644 index 0000000000000000000000000000000000000000..27088cc6fd9899c43fbdf9d9e70a3e8f94917b4d --- /dev/null +++ b/github_page/uci2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57ad677ff9af49ed2dbe75732af7e9fc9d1b357293f978d94fbb99278a84b500 +size 15100883 diff --git a/github_page/uci3.png b/github_page/uci3.png new file mode 100644 index 0000000000000000000000000000000000000000..01d980b8206e2ed09afedc46c249c0ee5606242e --- /dev/null +++ b/github_page/uci3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ec2c3e416d4ac25b2ec3b82ef08924244fbfff0e7794f193a34fe121d07ed2e +size 1462584 diff --git a/github_page/uci4.png b/github_page/uci4.png new file mode 100644 index 0000000000000000000000000000000000000000..9a83f6c4fa2849c8b4f103a2883cceb01518cefc --- /dev/null +++ b/github_page/uci4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee895a826b206893f7c0cac56bae3442021062df5f0b7eeb7c1bf72419711fb5 +size 7914066 diff --git a/gradio_annotator.py b/gradio_annotator.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1a29ebbec24073a9e4357b700e0577a17a9379 --- /dev/null +++ b/gradio_annotator.py @@ -0,0 +1,160 @@ +import gradio as gr + +from annotator.util import resize_image, HWC3 + + +model_canny = None + + +def canny(img, res, l, h): + img = resize_image(HWC3(img), res) + global model_canny + if model_canny is None: + from annotator.canny import CannyDetector + model_canny = CannyDetector() + result = model_canny(img, l, h) + return [result] + + +model_hed = None + + +def hed(img, res): + img = resize_image(HWC3(img), res) + global model_hed + if model_hed is None: + from annotator.hed import HEDdetector + model_hed = HEDdetector() + result = model_hed(img) + return [result] + + +model_mlsd = None + + +def mlsd(img, res, thr_v, thr_d): + img = resize_image(HWC3(img), res) + global model_mlsd + if model_mlsd is None: + from annotator.mlsd import MLSDdetector + model_mlsd = MLSDdetector() + result = model_mlsd(img, thr_v, thr_d) + return [result] + + +model_midas = None + + +def midas(img, res, a): + img = resize_image(HWC3(img), res) + global model_midas + if model_midas is None: + from annotator.midas import MidasDetector + model_midas = MidasDetector() + results = model_midas(img, a) + return results + + +model_openpose = None + + +def openpose(img, res, has_hand): + img = resize_image(HWC3(img), res) + global model_openpose + if model_openpose is None: + from annotator.openpose import OpenposeDetector + model_openpose = OpenposeDetector() + result, _ = model_openpose(img, has_hand) + return [result] + + +model_uniformer = None + + +def uniformer(img, res): + img = resize_image(HWC3(img), res) + global model_uniformer + if model_uniformer is None: + from annotator.uniformer import UniformerDetector + model_uniformer = UniformerDetector() + result = model_uniformer(img) + return [result] + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Canny Edge") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + low_threshold = gr.Slider(label="low_threshold", minimum=1, maximum=255, value=100, step=1) + high_threshold = gr.Slider(label="high_threshold", minimum=1, maximum=255, value=200, step=1) + resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64) + run_button = gr.Button(label="Run") + with gr.Column(): + gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto") + run_button.click(fn=canny, inputs=[input_image, resolution, low_threshold, high_threshold], outputs=[gallery]) + + with gr.Row(): + gr.Markdown("## HED Edge") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64) + run_button = gr.Button(label="Run") + with gr.Column(): + gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto") + run_button.click(fn=hed, inputs=[input_image, resolution], outputs=[gallery]) + + with gr.Row(): + gr.Markdown("## MLSD Edge") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + value_threshold = gr.Slider(label="value_threshold", minimum=0.01, maximum=2.0, value=0.1, step=0.01) + distance_threshold = gr.Slider(label="distance_threshold", minimum=0.01, maximum=20.0, value=0.1, step=0.01) + resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=384, step=64) + run_button = gr.Button(label="Run") + with gr.Column(): + gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto") + run_button.click(fn=mlsd, inputs=[input_image, resolution, value_threshold, distance_threshold], outputs=[gallery]) + + with gr.Row(): + gr.Markdown("## MIDAS Depth and Normal") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + alpha = gr.Slider(label="alpha", minimum=0.1, maximum=20.0, value=6.2, step=0.01) + resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=384, step=64) + run_button = gr.Button(label="Run") + with gr.Column(): + gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto") + run_button.click(fn=midas, inputs=[input_image, resolution, alpha], outputs=[gallery]) + + with gr.Row(): + gr.Markdown("## Openpose") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + hand = gr.Checkbox(label='detect hand', value=False) + resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64) + run_button = gr.Button(label="Run") + with gr.Column(): + gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto") + run_button.click(fn=openpose, inputs=[input_image, resolution, hand], outputs=[gallery]) + + + with gr.Row(): + gr.Markdown("## Uniformer Segmentation") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64) + run_button = gr.Button(label="Run") + with gr.Column(): + gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto") + run_button.click(fn=uniformer, inputs=[input_image, resolution], outputs=[gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_canny2image.py b/gradio_canny2image.py new file mode 100644 index 0000000000000000000000000000000000000000..9866cac5b35925576c20ef4b9ee8b1b1cca235b2 --- /dev/null +++ b/gradio_canny2image.py @@ -0,0 +1,97 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.canny import CannyDetector +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +apply_canny = CannyDetector() + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold): + with torch.no_grad(): + img = resize_image(HWC3(input_image), image_resolution) + H, W, C = img.shape + + detected_map = apply_canny(img, low_threshold, high_threshold) + detected_map = HWC3(detected_map) + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [255 - detected_map] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Canny Edge Maps") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1) + high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_depth2image.py b/gradio_depth2image.py new file mode 100644 index 0000000000000000000000000000000000000000..ee678999ae6033c18a5026bc5f6286d0364c7851 --- /dev/null +++ b/gradio_depth2image.py @@ -0,0 +1,98 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.midas import MidasDetector +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +apply_midas = MidasDetector() + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + input_image = HWC3(input_image) + detected_map, _ = apply_midas(resize_image(input_image, detect_resolution)) + detected_map = HWC3(detected_map) + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [detected_map] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Depth Maps") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_fake_scribble2image.py b/gradio_fake_scribble2image.py new file mode 100644 index 0000000000000000000000000000000000000000..a7cd375f7589c3f7c43b7df91802eb4bf87ea0e0 --- /dev/null +++ b/gradio_fake_scribble2image.py @@ -0,0 +1,102 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.hed import HEDdetector, nms +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +apply_hed = HEDdetector() + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + input_image = HWC3(input_image) + detected_map = apply_hed(resize_image(input_image, detect_resolution)) + detected_map = HWC3(detected_map) + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + detected_map = nms(detected_map, 127, 3.0) + detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) + detected_map[detected_map > 4] = 255 + detected_map[detected_map < 255] = 0 + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [255 - detected_map] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Fake Scribble Maps") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_hed2image.py b/gradio_hed2image.py new file mode 100644 index 0000000000000000000000000000000000000000..1ceff67969b7c64a0adcf0557f922c71dd4bfab7 --- /dev/null +++ b/gradio_hed2image.py @@ -0,0 +1,98 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.hed import HEDdetector +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +apply_hed = HEDdetector() + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + input_image = HWC3(input_image) + detected_map = apply_hed(resize_image(input_image, detect_resolution)) + detected_map = HWC3(detected_map) + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [detected_map] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with HED Maps") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_hough2image.py b/gradio_hough2image.py new file mode 100644 index 0000000000000000000000000000000000000000..6095eeb6767e005a155ee72057b3537021b09f31 --- /dev/null +++ b/gradio_hough2image.py @@ -0,0 +1,100 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.mlsd import MLSDdetector +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +apply_mlsd = MLSDdetector() + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold): + with torch.no_grad(): + input_image = HWC3(input_image) + detected_map = apply_mlsd(resize_image(input_image, detect_resolution), value_threshold, distance_threshold) + detected_map = HWC3(detected_map) + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [255 - cv2.dilate(detected_map, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Hough Line Maps") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + detect_resolution = gr.Slider(label="Hough Resolution", minimum=128, maximum=1024, value=512, step=1) + value_threshold = gr.Slider(label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01) + distance_threshold = gr.Slider(label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_humanpose2image.py b/gradio_humanpose2image.py new file mode 100644 index 0000000000000000000000000000000000000000..5f3e715e68439d6688531cd4c840c2d512092f09 --- /dev/null +++ b/gradio_humanpose2image.py @@ -0,0 +1,115 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.openpose import OpenposeDetector +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler +from aagenerator import Generator + +# apply_openpose = OpenposeDetector() +# model = create_model('./models/cldm_v15.yaml').cpu() +# model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda')) +# model = model.cuda() +# ddim_sampler = DDIMSampler(model) + +model = Generator() + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, + guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + input_image = HWC3(input_image) + detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution)) + detected_map = HWC3(detected_map) + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], + "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], + "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( + [strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, + 255).astype( + np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [detected_map] + results + + +def human_gen(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, + guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + seed_everything(seed) + body_numpy, rgb, midas_depth_image, normal_image, rgb2 = model.run(input_image, prompt, steps=ddim_steps) + return [body_numpy, rgb, midas_depth_image, normal_image, rgb2] + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Human Pose") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=1024, value=512, step=1) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, + height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, + guess_mode, strength, scale, seed, eta] + run_button.click(fn=human_gen, inputs=ips, outputs=[result_gallery]) + +block.launch(server_name='0.0.0.0', + share=True) diff --git a/gradio_normal2image.py b/gradio_normal2image.py new file mode 100644 index 0000000000000000000000000000000000000000..30aea2f8d4a7956a609cd003f2fee23d2ab162b5 --- /dev/null +++ b/gradio_normal2image.py @@ -0,0 +1,99 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.midas import MidasDetector +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +apply_midas = MidasDetector() + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold): + with torch.no_grad(): + input_image = HWC3(input_image) + _, detected_map = apply_midas(resize_image(input_image, detect_resolution), bg_th=bg_threshold) + detected_map = HWC3(detected_map) + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [detected_map] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Normal Maps") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + detect_resolution = gr.Slider(label="Normal Resolution", minimum=128, maximum=1024, value=384, step=1) + bg_threshold = gr.Slider(label="Normal background threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.01) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_pose2image.py b/gradio_pose2image.py new file mode 100644 index 0000000000000000000000000000000000000000..76f0a0b285b5a9ecdc77c09fe7705c98038b89fb --- /dev/null +++ b/gradio_pose2image.py @@ -0,0 +1,99 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.openpose import OpenposeDetector +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +apply_openpose = OpenposeDetector() + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + input_image = HWC3(input_image) + detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution)) + detected_map = HWC3(detected_map) + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [detected_map] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Human Pose") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=1024, value=512, step=1) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0', + share=True) diff --git a/gradio_scribble2image.py b/gradio_scribble2image.py new file mode 100644 index 0000000000000000000000000000000000000000..8abbc25bdeaeae23b9032101336682657825ead4 --- /dev/null +++ b/gradio_scribble2image.py @@ -0,0 +1,92 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + img = resize_image(HWC3(input_image), image_resolution) + H, W, C = img.shape + + detected_map = np.zeros_like(img, dtype=np.uint8) + detected_map[np.min(img, axis=2) < 127] = 255 + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [255 - detected_map] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Scribble Maps") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_scribble2image_interactive.py b/gradio_scribble2image_interactive.py new file mode 100644 index 0000000000000000000000000000000000000000..7308bcc1bb8387bba10c026495e0dcddae91c2db --- /dev/null +++ b/gradio_scribble2image_interactive.py @@ -0,0 +1,102 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + img = resize_image(HWC3(input_image['mask'][:, :, 0]), image_resolution) + H, W, C = img.shape + + detected_map = np.zeros_like(img, dtype=np.uint8) + detected_map[np.min(img, axis=2) > 127] = 255 + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [255 - detected_map] + results + + +def create_canvas(w, h): + return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Interactive Scribbles") + with gr.Row(): + with gr.Column(): + canvas_width = gr.Slider(label="Canvas Width", minimum=256, maximum=1024, value=512, step=1) + canvas_height = gr.Slider(label="Canvas Height", minimum=256, maximum=1024, value=512, step=1) + create_button = gr.Button(label="Start", value='Open drawing canvas!') + input_image = gr.Image(source='upload', type='numpy', tool='sketch') + gr.Markdown(value='Do not forget to change your brush width to make it thinner. (Gradio do not allow developers to set brush width so you need to do it manually.) ' + 'Just click on the small pencil icon in the upper right corner of the above block.') + create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image]) + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/gradio_seg2image.py b/gradio_seg2image.py new file mode 100644 index 0000000000000000000000000000000000000000..c3854dc7624ed6a0a68f059c5001e4973da27587 --- /dev/null +++ b/gradio_seg2image.py @@ -0,0 +1,97 @@ +from share import * +import config + +import cv2 +import einops +import gradio as gr +import numpy as np +import torch +import random + +from pytorch_lightning import seed_everything +from annotator.util import resize_image, HWC3 +from annotator.uniformer import UniformerDetector +from cldm.model import create_model, load_state_dict +from cldm.ddim_hacked import DDIMSampler + + +apply_uniformer = UniformerDetector() + +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict('./models/control_sd15_seg.pth', location='cuda')) +model = model.cuda() +ddim_sampler = DDIMSampler(model) + + +def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta): + with torch.no_grad(): + input_image = HWC3(input_image) + detected_map = apply_uniformer(resize_image(input_image, detect_resolution)) + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) + + control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.stack([control for _ in range(num_samples)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w').clone() + + if seed == -1: + seed = random.randint(0, 65535) + seed_everything(seed) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} + un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} + shape = (4, H // 8, W // 8) + + if config.save_memory: + model.low_vram_shift(is_diffusing=True) + + model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 + samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, + shape, cond, verbose=False, eta=eta, + unconditional_guidance_scale=scale, + unconditional_conditioning=un_cond) + + if config.save_memory: + model.low_vram_shift(is_diffusing=False) + + x_samples = model.decode_first_stage(samples) + x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) + + results = [x_samples[i] for i in range(num_samples)] + return [detected_map] + results + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## Control Stable Diffusion with Segmentation Maps") + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + with gr.Accordion("Advanced options", open=False): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + guess_mode = gr.Checkbox(label='Guess Mode', value=False) + detect_resolution = gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024, value=512, step=1) + ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) + eta = gr.Number(label="eta (DDIM)", value=0.0) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') + n_prompt = gr.Textbox(label="Negative Prompt", + value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) + + +block.launch(server_name='0.0.0.0') diff --git a/ldm/__pycache__/util.cpython-38.pyc b/ldm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4fbd012942e95aad302686d8b38b4a0134636d5 Binary files /dev/null and b/ldm/__pycache__/util.cpython-38.pyc differ diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/data/util.py b/ldm/data/util.py new file mode 100644 index 0000000000000000000000000000000000000000..5b60ceb2349e3bd7900ff325740e2022d2903b1c --- /dev/null +++ b/ldm/data/util.py @@ -0,0 +1,24 @@ +import torch + +from ldm.modules.midas.api import load_midas_transform + + +class AddMiDaS(object): + def __init__(self, model_type): + super().__init__() + self.transform = load_midas_transform(model_type) + + def pt2np(self, x): + x = ((x + 1.0) * .5).detach().cpu().numpy() + return x + + def np2pt(self, x): + x = torch.from_numpy(x) * 2 - 1. + return x + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = self.pt2np(sample['jpg']) + x = self.transform({"image": x})["image"] + sample['midas_in'] = x + return sample \ No newline at end of file diff --git a/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/ldm/models/__pycache__/autoencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3827a9a7006f41c9d8d1bd67e29793fa194c0e1 Binary files /dev/null and b/ldm/models/__pycache__/autoencoder.cpython-38.pyc differ diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,219 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config +from ldm.modules.ema import LitEma + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0192327003ad4bcc6ba051dcdbaea00f264e3bd0 Binary files /dev/null and b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ae14ea3c76ec37cc66ac10e2938544b0ea82b7f Binary files /dev/null and b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc differ diff --git a/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1626f50bb5dd0a2f78de3e34e76081cd9155939 Binary files /dev/null and b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc differ diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144 --- /dev/null +++ b/ldm/models/diffusion/ddim.py @@ -0,0 +1,336 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..f71a44af48c8cba8e97849b7e6813b3e6f9fe83c --- /dev/null +++ b/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1797 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only +from omegaconf import ListConfig + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + if reset_ema: assert exists(ckpt_path) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + if reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + else: + self.register_buffer('logvar', logvar) + + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys:\n {missing}") + if len(unexpected) > 0: + print(f"\nUnexpected Keys:\n {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + force_null_conditioning=False, + *args, **kwargs): + self.force_null_conditioning = force_null_conditioning + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + reset_ema = kwargs.pop("reset_ema", False) + reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + if reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, return_x=False): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None and not self.force_null_conditioning: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox', "txt"]: + xc = batch[cond_key] + elif cond_key in ['class_label', 'cls']: + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_x: + out.extend([x]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is expected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None, **kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + if self.cond_stage_key in ["class_label", "cls"]: + xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) + return self.get_learned_conditioning(xc) + else: + raise NotImplementedError("todo") + if isinstance(c, list): # in case the encoder gives us a list + for i in range(len(c)): + c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + else: + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + return c + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', "cls"]: + try: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + except KeyError: + # probably no "human_label" in batch + pass + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) + if self.model.conditioning_key == "crossattn-adm": + uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + if not self.sequential_cross_attn: + cc = torch.cat(c_crossattn, 1) + else: + cc = c_crossattn + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + self.noise_level_key = noise_level_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + if self.noise_level_key is not None: + # get noise level from batch instead, e.g. when extracting a custom noise level for bsr + raise NotImplementedError('TODO') + + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + if log_mode: + # TODO: maybe disable if too expensive + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + uc[k] = c[k] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentFinetuneDiffusion(LatentDiffusion): + """ + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys: tuple, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight" + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log + + +class LatentInpaintDiffusion(LatentFinetuneDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, **kwargs + ): + super().__init__(concat_keys, *args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) + log["masked_image"] = rearrange(args[0]["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + c_cat = list() + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], + keepdim=True) + cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ + torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, + low_scale_config=None, low_scale_key=None, *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + print("Initializing a low-scale model") + assert exists(low_scale_key) + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + # optionally make spatial noise_level here + c_cat = list() + noise_level = None + for ck in self.concat_keys: + cc = batch[ck] + cc = rearrange(cc, 'b h w c -> b c h w') + if exists(self.reshuffle_patch_size): + assert isinstance(self.reshuffle_patch_size, int) + cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', + p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + if exists(self.low_scale_model) and ck == self.low_scale_key: + cc, noise_level = self.low_scale_model(cc) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + if exists(noise_level): + all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} + else: + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + return log diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08 --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1154 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( + model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + ===================================================== + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + ===================================================== + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in tqdm(range(1, order), desc="DPM init order"): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), desc="DPM multistep"): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), + N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,87 @@ +"""SAMPLING ONLY.""" +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +MODEL_TYPES = { + "eps": "noise", + "v": "v" +} + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None \ No newline at end of file diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae --- /dev/null +++ b/ldm/models/diffusion/plms.py @@ -0,0 +1,244 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33 --- /dev/null +++ b/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/ldm/modules/__pycache__/attention.cpython-38.pyc b/ldm/modules/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6da2545015189f2d77c57041caa7fe84c5824666 Binary files /dev/null and b/ldm/modules/__pycache__/attention.cpython-38.pyc differ diff --git a/ldm/modules/__pycache__/ema.cpython-38.pyc b/ldm/modules/__pycache__/ema.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b581ba6d0acce1423cac6b3b0171b24cf7478b7 Binary files /dev/null and b/ldm/modules/__pycache__/ema.cpython-38.pyc differ diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..509cd873768f0dd75a75ab3fcdd652822b12b59f --- /dev/null +++ b/ldm/modules/attention.py @@ -0,0 +1,341 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from ldm.modules.diffusionmodules.util import checkpoint + + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +# CrossAttn precision handling +import os +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d11401223222187d9c370c479e3d9b1272dd0258 Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc differ diff --git a/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d888102240952a6a380097c1c9f327624fb73cca Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc differ diff --git a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46774977022f56ac1bc2222580d2e8729c6dfef2 Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc differ diff --git a/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0abac2732293e18274a596c96e0b7caa4b8eac9c Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc differ diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,852 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from typing import Optional, Any + +from ldm.modules.attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..7df6b5abfe8eff07f0c8e8703ba8aee90d45984b --- /dev/null +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,786 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/ldm/modules/diffusionmodules/upscaling.py b/ldm/modules/diffusionmodules/upscaling.py new file mode 100644 index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988 --- /dev/null +++ b/ldm/modules/diffusionmodules/upscaling.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial + +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ldm.util import default + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..637363dfe34799e70cfdbcd11445212df9d9ca1f --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,270 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b51b9d91fcd3318267c855b1507b9b535a9e0e0 Binary files /dev/null and b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ diff --git a/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b39f71ae0a37a7de2d72fb480eb17c7c738389e7 Binary files /dev/null and b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4 --- /dev/null +++ b/ldm/modules/ema.py @@ -0,0 +1,80 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6e087c8a8d6e656f65dcd65fc6573935d7a339c Binary files /dev/null and b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ diff --git a/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7a310f70dc206567864d948e6d0823cb6339100 Binary files /dev/null and b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,213 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import open_clip +from ldm.util import default, count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc --- /dev/null +++ b/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b --- /dev/null +++ b/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0 --- /dev/null +++ b/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,651 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + if up: + image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6 Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98 --- /dev/null +++ b/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/ldm/modules/midas/__init__.py b/ldm/modules/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/midas/api.py b/ldm/modules/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c --- /dev/null +++ b/ldm/modules/midas/api.py @@ -0,0 +1,170 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from ldm.modules.midas.midas.dpt_depth import DPTDepthModel +from ldm.modules.midas.midas.midas_net import MidasNet +from ldm.modules.midas.midas.midas_net_custom import MidasNet_small +from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet + + +ISL_PATHS = { + "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array + # NOTE: we expect that the correct transform has been called during dataloading. + with torch.no_grad(): + prediction = self.model(x) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=x.shape[2:], + mode="bicubic", + align_corners=False, + ) + assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) + return prediction + diff --git a/ldm/modules/midas/midas/__init__.py b/ldm/modules/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/midas/midas/base_model.py b/ldm/modules/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/ldm/modules/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/ldm/modules/midas/midas/blocks.py b/ldm/modules/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/ldm/modules/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/ldm/modules/midas/midas/dpt_depth.py b/ldm/modules/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/ldm/modules/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/ldm/modules/midas/midas/midas_net.py b/ldm/modules/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/ldm/modules/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/ldm/modules/midas/midas/midas_net_custom.py b/ldm/modules/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/ldm/modules/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/ldm/modules/midas/midas/transforms.py b/ldm/modules/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/ldm/modules/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/ldm/modules/midas/midas/vit.py b/ldm/modules/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/ldm/modules/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/ldm/modules/midas/utils.py b/ldm/modules/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/ldm/modules/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..45cb050ece6f401a22dde098ce3f1ff663c5eb6a --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,197 @@ +import importlib + +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/mmpose/__init__.py b/mmpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7946470df9a9e16830af885aeee31e3aaee6ca --- /dev/null +++ b/mmpose/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import mmengine +from mmengine.utils import digit_version + +from .version import __version__, short_version + +mmcv_minimum_version = '2.0.0rc4' +mmcv_maximum_version = '2.1.0' +mmcv_version = digit_version(mmcv.__version__) + +mmengine_minimum_version = '0.6.0' +mmengine_maximum_version = '1.0.0' +mmengine_version = digit_version(mmengine.__version__) + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version <= digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' + +assert (mmengine_version >= digit_version(mmengine_minimum_version) + and mmengine_version <= digit_version(mmengine_maximum_version)), \ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_minimum_version}, ' \ + f'<={mmengine_maximum_version}.' + +__all__ = ['__version__', 'short_version'] diff --git a/mmpose/__pycache__/__init__.cpython-38.pyc b/mmpose/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6684280e5e0d985451799dc57d98757885fde04e Binary files /dev/null and b/mmpose/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/__pycache__/registry.cpython-38.pyc b/mmpose/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f95cca270581e9bf04dc43bfc6ada4b51d0f8b5c Binary files /dev/null and b/mmpose/__pycache__/registry.cpython-38.pyc differ diff --git a/mmpose/__pycache__/version.cpython-38.pyc b/mmpose/__pycache__/version.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3185f2085a525eefb6a80732305e29e842939e24 Binary files /dev/null and b/mmpose/__pycache__/version.cpython-38.pyc differ diff --git a/mmpose/apis/__init__.py b/mmpose/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff7149e453552aefe6c3beb35404f281df15943a --- /dev/null +++ b/mmpose/apis/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .inference import inference_bottomup, inference_topdown, init_model +from .inferencers import MMPoseInferencer, Pose2DInferencer + +__all__ = [ + 'init_model', 'inference_topdown', 'inference_bottomup', + 'Pose2DInferencer', 'MMPoseInferencer' +] diff --git a/mmpose/apis/__pycache__/__init__.cpython-38.pyc b/mmpose/apis/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbe08892135200eb55616826d7f1ed87fb0745e7 Binary files /dev/null and b/mmpose/apis/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/apis/__pycache__/inference.cpython-38.pyc b/mmpose/apis/__pycache__/inference.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e21aeea460cb62e28037714dc3d0f4362b7a0598 Binary files /dev/null and b/mmpose/apis/__pycache__/inference.cpython-38.pyc differ diff --git a/mmpose/apis/inference.py b/mmpose/apis/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6763d318d5d7fef3f267c235ce6f8210ac6078ce --- /dev/null +++ b/mmpose/apis/inference.py @@ -0,0 +1,225 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from pathlib import Path +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from mmengine.config import Config +from mmengine.dataset import Compose, pseudo_collate +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint +from PIL import Image + +from mmpose.datasets.datasets.utils import parse_pose_metainfo +from mmpose.models.builder import build_pose_estimator +from mmpose.structures import PoseDataSample +from mmpose.structures.bbox import bbox_xywh2xyxy + + +def dataset_meta_from_config(config: Config, + dataset_mode: str = 'train') -> Optional[dict]: + """Get dataset metainfo from the model config. + + Args: + config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, + :obj:`Path`, or the config object. + dataset_mode (str): Specify the dataset of which to get the metainfo. + Options are ``'train'``, ``'val'`` and ``'test'``. Defaults to + ``'train'`` + + Returns: + dict, optional: The dataset metainfo. See + ``mmpose.datasets.datasets.utils.parse_pose_metainfo`` for details. + Return ``None`` if failing to get dataset metainfo from the config. + """ + try: + if dataset_mode == 'train': + dataset_cfg = config.train_dataloader.dataset + elif dataset_mode == 'val': + dataset_cfg = config.val_dataloader.dataset + elif dataset_mode == 'test': + dataset_cfg = config.test_dataloader.dataset + else: + raise ValueError( + f'Invalid dataset {dataset_mode} to get metainfo. ' + 'Should be one of "train", "val", or "test".') + + if 'metainfo' in dataset_cfg: + metainfo = dataset_cfg.metainfo + else: + import mmpose.datasets.datasets # noqa: F401, F403 + from mmpose.registry import DATASETS + + dataset_class = DATASETS.get(dataset_cfg.type) + metainfo = dataset_class.METAINFO + + metainfo = parse_pose_metainfo(metainfo) + + except AttributeError: + metainfo = None + + return metainfo + + +def init_model(config: Union[str, Path, Config], + checkpoint: Optional[str] = None, + device: str = 'cuda:0', + cfg_options: Optional[dict] = None) -> nn.Module: + """Initialize a pose estimator from a config file. + + Args: + config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, + :obj:`Path`, or the config object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. Defaults to ``None`` + device (str): The device where the anchors will be put on. + Defaults to ``'cuda:0'``. + cfg_options (dict, optional): Options to override some settings in + the used config. Defaults to ``None`` + + Returns: + nn.Module: The constructed pose estimator. + """ + + if isinstance(config, (str, Path)): + config = Config.fromfile(config) + elif not isinstance(config, Config): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(config)}') + if cfg_options is not None: + config.merge_from_dict(cfg_options) + elif 'init_cfg' in config.model.backbone: + config.model.backbone.init_cfg = None + config.model.train_cfg = None + + # register all modules in mmpose into the registries + init_default_scope(config.get('default_scope', 'mmpose')) + + model = build_pose_estimator(config.model) + model = revert_sync_batchnorm(model) + # get dataset_meta in this priority: checkpoint > config > default (COCO) + dataset_meta = None + + if checkpoint is not None: + ckpt = load_checkpoint(model, checkpoint, map_location='cpu') + + if 'dataset_meta' in ckpt.get('meta', {}): + # checkpoint from mmpose 1.x + dataset_meta = ckpt['meta']['dataset_meta'] + + if dataset_meta is None: + dataset_meta = dataset_meta_from_config(config, dataset_mode='train') + + if dataset_meta is None: + warnings.simplefilter('once') + warnings.warn('Can not load dataset_meta from the checkpoint or the ' + 'model config. Use COCO metainfo by default.') + dataset_meta = parse_pose_metainfo( + dict(from_file='configs/_base_/datasets/coco.py')) + + model.dataset_meta = dataset_meta + + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +def inference_topdown(model: nn.Module, + img: Union[np.ndarray, str], + bboxes: Optional[Union[List, np.ndarray]] = None, + bbox_format: str = 'xyxy') -> List[PoseDataSample]: + """Inference image with a top-down pose estimator. + + Args: + model (nn.Module): The top-down pose estimator + img (np.ndarray | str): The loaded image or image file to inference + bboxes (np.ndarray, optional): The bboxes in shape (N, 4), each row + represents a bbox. If not given, the entire image will be regarded + as a single bbox area. Defaults to ``None`` + bbox_format (str): The bbox format indicator. Options are ``'xywh'`` + and ``'xyxy'``. Defaults to ``'xyxy'`` + + Returns: + List[:obj:`PoseDataSample`]: The inference results. Specifically, the + predicted keypoints and scores are saved at + ``data_sample.pred_instances.keypoints`` and + ``data_sample.pred_instances.keypoint_scores``. + """ + init_default_scope(model.cfg.get('default_scope', 'mmpose')) + pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline) + + if bboxes is None: + # get bbox from the image size + if isinstance(img, str): + w, h = Image.open(img).size + else: + h, w = img.shape[:2] + + bboxes = np.array([[0, 0, w, h]], dtype=np.float32) + else: + if isinstance(bboxes, list): + bboxes = np.array(bboxes) + + assert bbox_format in {'xyxy', 'xywh'}, \ + f'Invalid bbox_format "{bbox_format}".' + + if bbox_format == 'xywh': + bboxes = bbox_xywh2xyxy(bboxes) + + # construct batch data samples + data_list = [] + for bbox in bboxes: + if isinstance(img, str): + data_info = dict(img_path=img) + else: + data_info = dict(img=img) + data_info['bbox'] = bbox[None] # shape (1, 4) + data_info['bbox_score'] = np.ones(1, dtype=np.float32) # shape (1,) + data_info.update(model.dataset_meta) + data_list.append(pipeline(data_info)) + + if data_list: + # collate data list into a batch, which is a dict with following keys: + # batch['inputs']: a list of input images + # batch['data_samples']: a list of :obj:`PoseDataSample` + batch = pseudo_collate(data_list) + with torch.no_grad(): + results = model.test_step(batch) + else: + results = [] + + return results + + +def inference_bottomup(model: nn.Module, img: Union[np.ndarray, str]): + """Inference image with a bottom-up pose estimator. + + Args: + model (nn.Module): The bottom-up pose estimator + img (np.ndarray | str): The loaded image or image file to inference + + Returns: + List[:obj:`PoseDataSample`]: The inference results. Specifically, the + predicted keypoints and scores are saved at + ``data_sample.pred_instances.keypoints`` and + ``data_sample.pred_instances.keypoint_scores``. + """ + pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline) + + # prepare data batch + if isinstance(img, str): + data_info = dict(img_path=img) + else: + data_info = dict(img=img) + data_info.update(model.dataset_meta) + data = pipeline(data_info) + batch = pseudo_collate([data]) + + with torch.no_grad(): + results = model.test_step(batch) + + return results diff --git a/mmpose/apis/inferencers/__init__.py b/mmpose/apis/inferencers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3db192da738d98acca3739a501701024cf2fcb02 --- /dev/null +++ b/mmpose/apis/inferencers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .mmpose_inferencer import MMPoseInferencer +from .pose2d_inferencer import Pose2DInferencer +from .utils import get_model_aliases + +__all__ = ['Pose2DInferencer', 'MMPoseInferencer', 'get_model_aliases'] diff --git a/mmpose/apis/inferencers/__pycache__/__init__.cpython-38.pyc b/mmpose/apis/inferencers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9086cfd4833008cfd13339e4bc768f6ba56c48d Binary files /dev/null and b/mmpose/apis/inferencers/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/apis/inferencers/__pycache__/base_mmpose_inferencer.cpython-38.pyc b/mmpose/apis/inferencers/__pycache__/base_mmpose_inferencer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..931c890f2b84376e361bc6018ac62f7242ddeadb Binary files /dev/null and b/mmpose/apis/inferencers/__pycache__/base_mmpose_inferencer.cpython-38.pyc differ diff --git a/mmpose/apis/inferencers/__pycache__/mmpose_inferencer.cpython-38.pyc b/mmpose/apis/inferencers/__pycache__/mmpose_inferencer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fd9067ddaf07bf67e8700b0ba359643cfd93c45 Binary files /dev/null and b/mmpose/apis/inferencers/__pycache__/mmpose_inferencer.cpython-38.pyc differ diff --git a/mmpose/apis/inferencers/__pycache__/pose2d_inferencer.cpython-38.pyc b/mmpose/apis/inferencers/__pycache__/pose2d_inferencer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..936ad40424e6b4b0361d85063b13b1fc7bd9c411 Binary files /dev/null and b/mmpose/apis/inferencers/__pycache__/pose2d_inferencer.cpython-38.pyc differ diff --git a/mmpose/apis/inferencers/base_mmpose_inferencer.py b/mmpose/apis/inferencers/base_mmpose_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..86e61463b698f7e01942d1b204f92f975ff472d1 --- /dev/null +++ b/mmpose/apis/inferencers/base_mmpose_inferencer.py @@ -0,0 +1,435 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mimetypes +import os +import warnings +from collections import defaultdict +from typing import (Callable, Dict, Generator, Iterable, List, Optional, + Sequence, Union) + +import cv2 +import mmcv +import mmengine +import numpy as np +import torch.nn as nn +from mmengine.config import Config, ConfigDict +from mmengine.dataset import Compose +from mmengine.fileio import (get_file_backend, isdir, join_path, + list_dir_or_file) +from mmengine.infer.infer import BaseInferencer +from mmengine.runner.checkpoint import _load_checkpoint_to_model +from mmengine.structures import InstanceData +from mmengine.utils import mkdir_or_exist + +from mmpose.apis.inference import dataset_meta_from_config +from mmpose.structures import PoseDataSample, split_instances + +InstanceList = List[InstanceData] +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[InstanceData, InstanceList] +ImgType = Union[np.ndarray, Sequence[np.ndarray]] +ConfigType = Union[Config, ConfigDict] +ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] + + +class BaseMMPoseInferencer(BaseInferencer): + """The base class for MMPose inferencers.""" + + preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'} + forward_kwargs: set = set() + visualize_kwargs: set = { + 'return_vis', + 'show', + 'wait_time', + 'draw_bbox', + 'radius', + 'thickness', + 'kpt_thr', + 'vis_out_dir', + } + postprocess_kwargs: set = {'pred_out_dir'} + + def _load_weights_to_model(self, model: nn.Module, + checkpoint: Optional[dict], + cfg: Optional[ConfigType]) -> None: + """Loading model weights and meta information from cfg and checkpoint. + + Subclasses could override this method to load extra meta information + from ``checkpoint`` and ``cfg`` to model. + + Args: + model (nn.Module): Model to load weights and meta information. + checkpoint (dict, optional): The loaded checkpoint. + cfg (Config or ConfigDict, optional): The loaded config. + """ + if checkpoint is not None: + _load_checkpoint_to_model(model, checkpoint) + checkpoint_meta = checkpoint.get('meta', {}) + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint_meta: + # mmpose 1.x + model.dataset_meta = checkpoint_meta['dataset_meta'] + else: + warnings.warn( + 'dataset_meta are not saved in the checkpoint\'s ' + 'meta data, load via config.') + model.dataset_meta = dataset_meta_from_config( + cfg, dataset_mode='train') + else: + warnings.warn('Checkpoint is not loaded, and the inference ' + 'result is calculated by the randomly initialized ' + 'model!') + model.dataset_meta = dataset_meta_from_config( + cfg, dataset_mode='train') + + def _inputs_to_list(self, inputs: InputsType) -> Iterable: + """Preprocess the inputs to a list. + + Preprocess inputs to a list according to its type: + + - list or tuple: return inputs + - str: + - Directory path: return all files in the directory + - other cases: return a list containing the string. The string + could be a path to file, a url or other types of string + according to the task. + + Args: + inputs (InputsType): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. + """ + self._video_input = False + + if isinstance(inputs, str): + backend = get_file_backend(inputs) + if hasattr(backend, 'isdir') and isdir(inputs): + # Backends like HttpsBackend do not implement `isdir`, so only + # those backends that implement `isdir` could accept the + # inputs as a directory + filepath_list = [ + join_path(inputs, fname) + for fname in list_dir_or_file(inputs, list_dir=False) + ] + inputs = [] + for filepath in filepath_list: + input_type = mimetypes.guess_type(filepath)[0].split( + '/')[0] + if input_type == 'image': + inputs.append(filepath) + inputs.sort() + else: + # if inputs is a path to a video file, it will be converted + # to a list containing separated frame filenames + input_type = mimetypes.guess_type(inputs)[0].split('/')[0] + if input_type == 'video': + self._video_input = True + video = mmcv.VideoReader(inputs) + self.video_info = dict( + fps=video.fps, + name=os.path.basename(inputs), + writer=None, + predictions=[]) + inputs = video + elif input_type == 'image': + inputs = [inputs] + else: + raise ValueError(f'Expected input to be an image, video, ' + f'or folder, but received {inputs} of ' + f'type {input_type}.') + + elif isinstance(inputs, np.ndarray): + inputs = [inputs] + + return inputs + + def _get_webcam_inputs(self, inputs: str) -> Generator: + """Sets up and returns a generator function that reads frames from a + webcam input. The generator function returns a new frame each time it + is iterated over. + + Args: + inputs (str): A string describing the webcam input, in the format + "webcam:id". + + Returns: + A generator function that yields frames from the webcam input. + + Raises: + ValueError: If the inputs string is not in the expected format. + """ + assert getattr(self.visualizer, 'backend', None) == 'opencv', \ + 'Visualizer must utilize the OpenCV backend in order to ' \ + 'support webcam inputs.' + + # Ensure the inputs string is in the expected format. + inputs = inputs.lower() + assert inputs.startswith('webcam'), f'Expected input to start with ' \ + f'"webcam", but got "{inputs}"' + + # Parse the camera ID from the inputs string. + inputs_ = inputs.split(':') + if len(inputs_) == 1: + camera_id = 0 + elif len(inputs_) == 2 and str.isdigit(inputs_[1]): + camera_id = int(inputs_[1]) + else: + raise ValueError( + f'Expected webcam input to have format "webcam:id", ' + f'but got "{inputs}"') + + # Attempt to open the video capture object. + vcap = cv2.VideoCapture(camera_id) + if not vcap.isOpened(): + warnings.warn(f'Cannot open camera (ID={camera_id})') + return [] + + # Set video input flag and metadata. + self._video_input = True + self.video_info = dict( + fps=10, name='webcam.mp4', writer=None, predictions=[]) + + def _webcam_reader() -> Generator: + while True: + if cv2.waitKey(5) & 0xFF == 27: + vcap.release() + break + + ret_val, frame = vcap.read() + if not ret_val: + break + + yield frame + + return _webcam_reader() + + def _visualization_window_on_close(self, event): + self._window_closing = True + + def _init_pipeline(self, cfg: ConfigType) -> Callable: + """Initialize the test pipeline. + + Args: + cfg (ConfigType): model config path or dict + + Returns: + A pipeline to handle various input data, such as ``str``, + ``np.ndarray``. The returned pipeline will be used to process + a single data. + """ + return Compose(cfg.test_dataloader.dataset.pipeline) + + def preprocess(self, + inputs: InputsType, + batch_size: int = 1, + bboxes: Optional[List] = None, + **kwargs): + """Process the inputs into a model-feedable format. + + Args: + inputs (InputsType): Inputs given by user. + batch_size (int): batch size. Defaults to 1. + + Yields: + Any: Data processed by the ``pipeline`` and ``collate_fn``. + List[str or np.ndarray]: List of original inputs in the batch + """ + + for i, input in enumerate(inputs): + bbox = bboxes[i] if bboxes is not None else [] + data_infos = self.preprocess_single( + input, index=i, bboxes=bbox, **kwargs) + # only supports inference with batch size 1 + yield self.collate_fn(data_infos), [input] + + def visualize(self, + inputs: list, + preds: List[PoseDataSample], + return_vis: bool = False, + show: bool = False, + draw_bbox: bool = False, + wait_time: float = 0, + radius: int = 3, + thickness: int = 1, + kpt_thr: float = 0.3, + vis_out_dir: str = '', + window_name: str = '', + window_close_event_handler: Optional[Callable] = None + ) -> List[np.ndarray]: + """Visualize predictions. + + Args: + inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. + preds (Any): Predictions of the model. + return_vis (bool): Whether to return images with predicted results. + show (bool): Whether to display the image in a popup window. + Defaults to False. + wait_time (float): The interval of show (ms). Defaults to 0 + draw_bbox (bool): Whether to draw the bounding boxes. + Defaults to False + radius (int): Keypoint radius for visualization. Defaults to 3 + thickness (int): Link thickness for visualization. Defaults to 1 + kpt_thr (float): The threshold to visualize the keypoints. + Defaults to 0.3 + vis_out_dir (str, optional): Directory to save visualization + results w/o predictions. If left as empty, no file will + be saved. Defaults to ''. + window_name (str, optional): Title of display window. + window_close_event_handler (callable, optional): + + Returns: + List[np.ndarray]: Visualization results. + """ + if (not return_vis) and (not show) and (not vis_out_dir): + return + + if getattr(self, 'visualizer', None) is None: + raise ValueError('Visualization needs the "visualizer" term' + 'defined in the config, but got None.') + + self.visualizer.radius = radius + self.visualizer.line_width = thickness + + results = [] + + for single_input, pred in zip(inputs, preds): + if isinstance(single_input, str): + img = mmcv.imread(single_input, channel_order='rgb') + elif isinstance(single_input, np.ndarray): + img = mmcv.bgr2rgb(single_input) + else: + raise ValueError('Unsupported input type: ' + f'{type(single_input)}') + + img_name = os.path.basename(pred.metainfo['img_path']) + window_name = window_name if window_name else img_name + + # since visualization and inference utilize the same process, + # the wait time is reduced when a video input is utilized, + # thereby eliminating the issue of inference getting stuck. + wait_time = 1e-5 if self._video_input else wait_time + + visualization = self.visualizer.add_datasample( + window_name, + img, + pred, + draw_gt=False, + draw_bbox=draw_bbox, + draw_heatmap=True, + show=show, + wait_time=wait_time, + kpt_thr=kpt_thr) + results.append(visualization) + + if vis_out_dir: + out_img = mmcv.rgb2bgr(visualization) + + if self._video_input: + + if self.video_info['writer'] is None: + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + mkdir_or_exist(vis_out_dir) + out_file = join_path( + vis_out_dir, + os.path.basename(self.video_info['name'])) + self.video_info['writer'] = cv2.VideoWriter( + out_file, fourcc, self.video_info['fps'], + (visualization.shape[1], visualization.shape[0])) + self.video_info['writer'].write(out_img) + + else: + out_file = join_path(vis_out_dir, img_name) + mmcv.imwrite(out_img, out_file) + + if return_vis: + return results + else: + return [] + + def postprocess( + self, + preds: List[PoseDataSample], + visualization: List[np.ndarray], + return_datasample=False, + pred_out_dir: str = '', + ) -> dict: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Convert datasamples into a json-serializable dict if needed. + 2. Pack the predictions and visualization results and return them. + 3. Dump or log the predictions. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (np.ndarray): Visualized predictions. + return_datasample (bool): Whether to return results as + datasamples. Defaults to False. + pred_out_dir (str): Directory to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization`` + + - ``visualization (Any)``: Returned by :meth:`visualize` + - ``predictions`` (dict or DataSample): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it usually should be a + json-serializable dict containing only basic data elements such + as strings and numbers. + """ + + result_dict = defaultdict(list) + + result_dict['visualization'] = visualization + for pred in preds: + if not return_datasample: + # convert datasamples to list of instance predictions + pred = split_instances(pred.pred_instances) + result_dict['predictions'].append(pred) + + if pred_out_dir != '': + for pred, data_sample in zip(result_dict['predictions'], preds): + if self._video_input: + self.video_info['predictions'].append(pred) + else: + fname = os.path.splitext( + os.path.basename( + data_sample.metainfo['img_path']))[0] + '.json' + mmengine.dump( + pred, join_path(pred_out_dir, fname), indent=' ') + + return result_dict + + def _finalize_video_processing( + self, + pred_out_dir: str = '', + ): + """Finalize video processing by releasing the video writer and saving + predictions to a file. + + This method should be called after completing the video processing. It + releases the video writer, if it exists, and saves the predictions to a + JSON file if a prediction output directory is provided. + """ + + # Release the video writer if it exists + if self.video_info['writer'] is not None: + self.video_info['writer'].release() + + # Save predictions + if pred_out_dir: + fname = os.path.splitext( + os.path.basename(self.video_info['name']))[0] + '.json' + predictions = [ + dict(frame_id=i, instances=pred) + for i, pred in enumerate(self.video_info['predictions']) + ] + + mmengine.dump( + predictions, join_path(pred_out_dir, fname), indent=' ') diff --git a/mmpose/apis/inferencers/mmpose_inferencer.py b/mmpose/apis/inferencers/mmpose_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..3d5ac222bb3a5c37b89531706c8ce2d94f5f55a1 --- /dev/null +++ b/mmpose/apis/inferencers/mmpose_inferencer.py @@ -0,0 +1,281 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.config import Config, ConfigDict +from mmengine.infer.infer import ModelType +from mmengine.structures import InstanceData +from rich.progress import track + +from mmpose.structures import PoseDataSample +from .base_mmpose_inferencer import BaseMMPoseInferencer +from .pose2d_inferencer import Pose2DInferencer + +InstanceList = List[InstanceData] +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[InstanceData, InstanceList] +ImgType = Union[np.ndarray, Sequence[np.ndarray]] +ConfigType = Union[Config, ConfigDict] +ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] + + +class MMPoseInferencer(BaseMMPoseInferencer): + """MMPose Inferencer. It's a unified inferencer interface for pose + estimation task, currently including: Pose2D. and it can be used to perform + 2D keypoint detection. + + Args: + pose2d (str, optional): Pretrained 2D pose estimation algorithm. + It's the path to the config file or the model name defined in + metafile. For example, it could be: + + - model alias, e.g. ``'body'``, + - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``, + - config path + + Defaults to ``None``. + pose2d_weights (str, optional): Path to the custom checkpoint file of + the selected pose2d model. If it is not specified and "pose2d" is + a model name of metafile, the weights will be loaded from + metafile. Defaults to None. + device (str, optional): Device to run inference. If None, the + available device will be automatically used. Defaults to None. + scope (str, optional): The scope of the model. Defaults to "mmpose". + det_model(str, optional): Config path or alias of detection model. + Defaults to None. + det_weights(str, optional): Path to the checkpoints of detection + model. Defaults to None. + det_cat_ids(int or list[int], optional): Category id for + detection model. Defaults to None. + output_heatmaps (bool, optional): Flag to visualize predicted + heatmaps. If set to None, the default setting from the model + config will be used. Default is None. + """ + + preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'} + forward_kwargs: set = set() + visualize_kwargs: set = { + 'return_vis', + 'show', + 'wait_time', + 'draw_bbox', + 'radius', + 'thickness', + 'kpt_thr', + 'vis_out_dir', + } + postprocess_kwargs: set = {'pred_out_dir'} + + def __init__(self, + pose2d: Optional[str] = None, + pose2d_weights: Optional[str] = None, + device: Optional[str] = None, + scope: str = 'mmpose', + det_model: Optional[Union[ModelType, str]] = None, + det_weights: Optional[str] = None, + det_cat_ids: Optional[Union[int, List]] = None, + output_heatmaps: Optional[bool] = None) -> None: + + if pose2d is None: + raise ValueError('2d pose estimation algorithm should provided.') + + self.visualizer = None + self.inferencers = dict() + if pose2d is not None: + self.inferencers['pose2d'] = Pose2DInferencer( + pose2d, pose2d_weights, device, scope, det_model, det_weights, + det_cat_ids, output_heatmaps) + + def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs): + """Process the inputs into a model-feedable format. + + Args: + inputs (InputsType): Inputs given by user. + batch_size (int): batch size. Defaults to 1. + + Yields: + Any: Data processed by the ``pipeline`` and ``collate_fn``. + List[str or np.ndarray]: List of original inputs in the batch + """ + + for i, input in enumerate(inputs): + data_batch = {} + if 'pose2d' in self.inferencers: + data_infos = self.inferencers['pose2d'].preprocess_single( + input, index=i, **kwargs) + data_batch['pose2d'] = self.inferencers['pose2d'].collate_fn( + data_infos) + # only supports inference with batch size 1 + yield data_batch, [input] + + @torch.no_grad() + def forward(self, inputs: InputType, **forward_kwargs) -> PredType: + """Forward the inputs to the model. + + Args: + inputs (InputsType): The inputs to be forwarded. + + Returns: + Dict: The prediction results. Possibly with keys "pose2d". + """ + result = {} + for mode, inferencer in self.inferencers.items(): + result[mode] = inferencer.forward(inputs[mode], **forward_kwargs) + + return result + + def __call__( + self, + inputs: InputsType, + return_datasample: bool = False, + batch_size: int = 1, + out_dir: Optional[str] = None, + **kwargs, + ) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. + return_datasample (bool): Whether to return results as + :obj:`BaseDataElement`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + out_dir (str, optional): directory to save visualization + results and predictions. Will be overoden if vis_out_dir or + pred_out_dir are given. Defaults to None + **kwargs: Key words arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, + ``visualize_kwargs`` and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. + """ + if out_dir is not None: + if 'vis_out_dir' not in kwargs: + kwargs['vis_out_dir'] = f'{out_dir}/visualizations' + if 'pred_out_dir' not in kwargs: + kwargs['pred_out_dir'] = f'{out_dir}/predictions' + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + # preprocessing + if isinstance(inputs, str) and inputs.startswith('webcam'): + inputs = self._get_webcam_inputs(inputs) + batch_size = 1 + if not visualize_kwargs.get('show', False): + warnings.warn('The display mode is closed when using webcam ' + 'input. It will be turned on automatically.') + visualize_kwargs['show'] = True + else: + inputs = self._inputs_to_list(inputs) + + inputs = self.preprocess( + inputs, batch_size=batch_size, **preprocess_kwargs) + + # forward + forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1) + for inferencer in self.inferencers.values(): + inferencer._video_input = self._video_input + if self._video_input: + inferencer.video_info = self.video_info + + preds = [] + if 'pose2d' not in self.inferencers or not hasattr( + self.inferencers['pose2d'], 'detector'): + inputs = track(inputs, description='Inference') + + for proc_inputs, ori_inputs in inputs: + preds = self.forward(proc_inputs, **forward_kwargs) + + visualization = self.visualize(ori_inputs, preds, + **visualize_kwargs) + results = self.postprocess(preds, visualization, return_datasample, + **postprocess_kwargs) + yield results + + if self._video_input: + self._finalize_video_processing( + postprocess_kwargs.get('pred_out_dir', '')) + + def visualize(self, inputs: InputsType, preds: PredType, + **kwargs) -> List[np.ndarray]: + """Visualize predictions. + + Args: + inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. + preds (Any): Predictions of the model. + return_vis (bool): Whether to return images with predicted results. + show (bool): Whether to display the image in a popup window. + Defaults to False. + show_interval (int): The interval of show (s). Defaults to 0 + radius (int): Keypoint radius for visualization. Defaults to 3 + thickness (int): Link thickness for visualization. Defaults to 1 + kpt_thr (float): The threshold to visualize the keypoints. + Defaults to 0.3 + vis_out_dir (str, optional): directory to save visualization + results w/o predictions. If left as empty, no file will + be saved. Defaults to ''. + + Returns: + List[np.ndarray]: Visualization results. + """ + + if 'pose2d' in self.inferencers: + window_name = '' + if self._video_input: + window_name = self.video_info['name'] + return self.inferencers['pose2d'].visualize( + inputs, + preds['pose2d'], + window_name=window_name, + window_close_event_handler=self._visualization_window_on_close, + **kwargs) + + def postprocess( + self, + preds: List[PoseDataSample], + visualization: List[np.ndarray], + return_datasample=False, + pred_out_dir: str = '', + ) -> dict: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Convert datasamples into a json-serializable dict if needed. + 2. Pack the predictions and visualization results and return them. + 3. Dump or log the predictions. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (np.ndarray): Visualized predictions. + return_datasample (bool): Whether to return results as + datasamples. Defaults to False. + pred_out_dir (str): Directory to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization`` + + - ``visualization (Any)``: Returned by :meth:`visualize` + - ``predictions`` (dict or DataSample): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it usually should be a + json-serializable dict containing only basic data elements such + as strings and numbers. + """ + + if 'pose2d' in self.inferencers: + return super().postprocess(preds['pose2d'], visualization, + return_datasample, pred_out_dir) diff --git a/mmpose/apis/inferencers/pose2d_inferencer.py b/mmpose/apis/inferencers/pose2d_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..b35abddb19aa35497d5df7da9d2f4c084dd38ecb --- /dev/null +++ b/mmpose/apis/inferencers/pose2d_inferencer.py @@ -0,0 +1,289 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import warnings +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +import torch +from mmengine.config import Config, ConfigDict +from mmengine.infer.infer import ModelType +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.structures import InstanceData +from rich.progress import track + +from mmpose.evaluation.functional import nms +from mmpose.registry import DATASETS, INFERENCERS +from mmpose.structures import merge_data_samples +from .base_mmpose_inferencer import BaseMMPoseInferencer +from .utils import default_det_models + +try: + from mmdet.apis.det_inferencer import DetInferencer + has_mmdet = True +except (ImportError, ModuleNotFoundError): + has_mmdet = False + +InstanceList = List[InstanceData] +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[InstanceData, InstanceList] +ImgType = Union[np.ndarray, Sequence[np.ndarray]] +ConfigType = Union[Config, ConfigDict] +ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] + + +@INFERENCERS.register_module(name='pose-estimation') +@INFERENCERS.register_module() +class Pose2DInferencer(BaseMMPoseInferencer): + """The inferencer for 2D pose estimation. + + Args: + model (str, optional): Pretrained 2D pose estimation algorithm. + It's the path to the config file or the model name defined in + metafile. For example, it could be: + + - model alias, e.g. ``'body'``, + - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``, + - config path + + Defaults to ``None``. + weights (str, optional): Path to the checkpoint. If it is not + specified and "model" is a model name of metafile, the weights + will be loaded from metafile. Defaults to None. + device (str, optional): Device to run inference. If None, the + available device will be automatically used. Defaults to None. + scope (str, optional): The scope of the model. Defaults to "mmpose". + det_model (str, optional): Config path or alias of detection model. + Defaults to None. + det_weights (str, optional): Path to the checkpoints of detection + model. Defaults to None. + det_cat_ids (int or list[int], optional): Category id for + detection model. Defaults to None. + output_heatmaps (bool, optional): Flag to visualize predicted + heatmaps. If set to None, the default setting from the model + config will be used. Default is None. + """ + + preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'} + forward_kwargs: set = set() + visualize_kwargs: set = { + 'return_vis', + 'show', + 'wait_time', + 'draw_bbox', + 'radius', + 'thickness', + 'kpt_thr', + 'vis_out_dir', + } + postprocess_kwargs: set = {'pred_out_dir'} + + def __init__(self, + model: Union[ModelType, str], + weights: Optional[str] = None, + device: Optional[str] = None, + scope: Optional[str] = 'mmpose', + det_model: Optional[Union[ModelType, str]] = None, + det_weights: Optional[str] = None, + det_cat_ids: Optional[Union[int, Tuple]] = None, + output_heatmaps: Optional[bool] = None) -> None: + + init_default_scope(scope) + super().__init__( + model=model, weights=weights, device=device, scope=scope) + self.model = revert_sync_batchnorm(self.model) + if output_heatmaps is not None: + self.model.test_cfg['output_heatmaps'] = output_heatmaps + + # assign dataset metainfo to self.visualizer + self.visualizer.set_dataset_meta(self.model.dataset_meta) + + # initialize detector for top-down models + if self.cfg.data_mode == 'topdown': + object_type = DATASETS.get(self.cfg.dataset_type).__module__.split( + 'datasets.')[-1].split('.')[0].lower() + + if det_model in ('whole_image', 'whole-image') or \ + (det_model is None and + object_type not in default_det_models): + self.detector = None + + else: + det_scope = 'mmdet' + if det_model is None: + det_info = default_det_models[object_type] + det_model, det_weights, det_cat_ids = det_info[ + 'model'], det_info['weights'], det_info['cat_ids'] + elif os.path.exists(det_model): + det_cfg = Config.fromfile(det_model) + det_scope = det_cfg.default_scope + + if has_mmdet: + self.detector = DetInferencer( + det_model, det_weights, device=device, scope=det_scope) + else: + raise RuntimeError( + 'MMDetection (v3.0.0 or above) is required to build ' + 'inferencers for top-down pose estimation models.') + + if isinstance(det_cat_ids, (tuple, list)): + self.det_cat_ids = det_cat_ids + else: + self.det_cat_ids = (det_cat_ids, ) + + self._video_input = False + + def preprocess_single(self, + input: InputType, + index: int, + bbox_thr: float = 0.3, + nms_thr: float = 0.3, + bboxes: Union[List[List], List[np.ndarray], + np.ndarray] = []): + """Process a single input into a model-feedable format. + + Args: + input (InputType): Input given by user. + index (int): index of the input + bbox_thr (float): threshold for bounding box detection. + Defaults to 0.3. + nms_thr (float): IoU threshold for bounding box NMS. + Defaults to 0.3. + + Yields: + Any: Data processed by the ``pipeline`` and ``collate_fn``. + """ + + if isinstance(input, str): + data_info = dict(img_path=input) + else: + data_info = dict(img=input, img_path=f'{index}.jpg'.rjust(10, '0')) + data_info.update(self.model.dataset_meta) + + if self.cfg.data_mode == 'topdown': + if self.detector is not None: + det_results = self.detector( + input, return_datasample=True)['predictions'] + pred_instance = det_results[0].pred_instances.cpu().numpy() + bboxes = np.concatenate( + (pred_instance.bboxes, pred_instance.scores[:, None]), + axis=1) + + label_mask = np.zeros(len(bboxes), dtype=np.uint8) + for cat_id in self.det_cat_ids: + label_mask = np.logical_or(label_mask, + pred_instance.labels == cat_id) + + bboxes = bboxes[np.logical_and( + label_mask, pred_instance.scores > bbox_thr)] + bboxes = bboxes[nms(bboxes, nms_thr)] + + data_infos = [] + if len(bboxes) > 0: + for bbox in bboxes: + inst = data_info.copy() + inst['bbox'] = bbox[None, :4] + inst['bbox_score'] = bbox[4:5] + data_infos.append(self.pipeline(inst)) + else: + inst = data_info.copy() + + # get bbox from the image size + if isinstance(input, str): + input = mmcv.imread(input) + h, w = input.shape[:2] + + inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32) + inst['bbox_score'] = np.ones(1, dtype=np.float32) + data_infos.append(self.pipeline(inst)) + + else: # bottom-up + data_infos = [self.pipeline(data_info)] + + return data_infos + + @torch.no_grad() + def forward(self, inputs: Union[dict, tuple], bbox_thr=-1): + data_samples = super().forward(inputs) + if self.cfg.data_mode == 'topdown': + data_samples = [merge_data_samples(data_samples)] + if bbox_thr > 0: + for ds in data_samples: + if 'bbox_scores' in ds.pred_instances: + ds.pred_instances = ds.pred_instances[ + ds.pred_instances.bbox_scores > bbox_thr] + return data_samples + + def __call__( + self, + inputs: InputsType, + return_datasample: bool = False, + batch_size: int = 1, + out_dir: Optional[str] = None, + **kwargs, + ) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. + return_datasample (bool): Whether to return results as + :obj:`BaseDataElement`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + out_dir (str, optional): directory to save visualization + results and predictions. Will be overoden if vis_out_dir or + pred_out_dir are given. Defaults to None + **kwargs: Key words arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, + ``visualize_kwargs`` and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. + """ + if out_dir is not None: + if 'vis_out_dir' not in kwargs: + kwargs['vis_out_dir'] = f'{out_dir}/visualizations' + if 'pred_out_dir' not in kwargs: + kwargs['pred_out_dir'] = f'{out_dir}/predictions' + + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + # preprocessing + if isinstance(inputs, str) and inputs.startswith('webcam'): + inputs = self._get_webcam_inputs(inputs) + batch_size = 1 + if not visualize_kwargs.get('show', False): + warnings.warn('The display mode is closed when using webcam ' + 'input. It will be turned on automatically.') + visualize_kwargs['show'] = True + else: + inputs = self._inputs_to_list(inputs) + + forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1) + inputs = self.preprocess( + inputs, batch_size=batch_size, **preprocess_kwargs) + + preds = [] + if not hasattr(self, 'detector'): + inputs = track(inputs, description='Inference') + + for proc_inputs, ori_inputs in inputs: + preds = self.forward(proc_inputs, **forward_kwargs) + + visualization = self.visualize(ori_inputs, preds, + **visualize_kwargs) + results = self.postprocess(preds, visualization, return_datasample, + **postprocess_kwargs) + yield results + + if self._video_input: + self._finalize_video_processing( + postprocess_kwargs.get('pred_out_dir', '')) diff --git a/mmpose/apis/inferencers/utils/__init__.py b/mmpose/apis/inferencers/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc40535b0d42a3b2ff41e97e26dcc30c440622b --- /dev/null +++ b/mmpose/apis/inferencers/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .default_det_models import default_det_models +from .get_model_alias import get_model_aliases + +__all__ = ['default_det_models', 'get_model_aliases'] diff --git a/mmpose/apis/inferencers/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/apis/inferencers/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8743eb7e1a7fa72b05cee58b23163c135412f60e Binary files /dev/null and b/mmpose/apis/inferencers/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/apis/inferencers/utils/__pycache__/default_det_models.cpython-38.pyc b/mmpose/apis/inferencers/utils/__pycache__/default_det_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..994aa4db0e13b3ffabe4158bfda1592c1ac4c2bb Binary files /dev/null and b/mmpose/apis/inferencers/utils/__pycache__/default_det_models.cpython-38.pyc differ diff --git a/mmpose/apis/inferencers/utils/__pycache__/get_model_alias.cpython-38.pyc b/mmpose/apis/inferencers/utils/__pycache__/get_model_alias.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc157a870b424def4cbac881fbc7283aa1a2d0c4 Binary files /dev/null and b/mmpose/apis/inferencers/utils/__pycache__/get_model_alias.cpython-38.pyc differ diff --git a/mmpose/apis/inferencers/utils/default_det_models.py b/mmpose/apis/inferencers/utils/default_det_models.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb3076880a313109dbfc72e4d9a0b0e3d6a7059 --- /dev/null +++ b/mmpose/apis/inferencers/utils/default_det_models.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from mmengine.config.utils import MODULE2PACKAGE +from mmengine.utils import get_installed_path + +mmpose_path = get_installed_path(MODULE2PACKAGE['mmpose']) + +default_det_models = dict( + human=dict(model='rtmdet-m', weights="/fsx_laion/alvin/pretrain/ViTPose/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth", cat_ids=(0, )), + face=dict( + model=osp.join(mmpose_path, '.mim', + 'demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py'), + weights='/fsx_laion/alvin/pretrain/ViTPose/yolo-x_8xb8-300e_coco-face_13274d7c.pth', + cat_ids=(0, )), + hand=dict( + model=osp.join( + mmpose_path, '.mim', 'demo/mmdetection_cfg/' + 'ssdlite_mobilenetv2_scratch_600e_onehand.py'), + weights='/fsx_laion/alvin/pretrain/ViTPose/ssdlite_mobilenetv2_scratch_600e_onehand-4f9f8686_20220523.pth', + cat_ids=(0, )), + animal=dict( + model='rtmdet-m', + weights=None, + cat_ids=(15, 16, 17, 18, 19, 20, 21, 22, 23)), +) + +default_det_models['body'] = default_det_models['human'] +default_det_models['wholebody'] = default_det_models['human'] diff --git a/mmpose/apis/inferencers/utils/get_model_alias.py b/mmpose/apis/inferencers/utils/get_model_alias.py new file mode 100644 index 0000000000000000000000000000000000000000..49de6528d6ea0df58cf7ae987176defbd4953739 --- /dev/null +++ b/mmpose/apis/inferencers/utils/get_model_alias.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +from mmengine.infer import BaseInferencer + + +def get_model_aliases(scope: str = 'mmpose') -> Dict[str, str]: + """Retrieve model aliases and their corresponding configuration names. + + Args: + scope (str, optional): The scope for the model aliases. Defaults + to 'mmpose'. + + Returns: + Dict[str, str]: A dictionary containing model aliases as keys and + their corresponding configuration names as values. + """ + + # Get a list of model configurations from the metafile + repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope) + model_cfgs = BaseInferencer._get_models_from_metafile(repo_or_mim_dir) + + model_alias_dict = dict() + for model_cfg in model_cfgs: + if 'Alias' in model_cfg: + if isinstance(model_cfg['Alias'], str): + model_alias_dict[model_cfg['Alias']] = model_cfg['Name'] + elif isinstance(model_cfg['Alias'], list): + for alias in model_cfg['Alias']: + model_alias_dict[alias] = model_cfg['Name'] + else: + raise ValueError( + 'encounter an unexpected alias type. Please raise an ' + 'issue at https://github.com/open-mmlab/mmpose/issues ' + 'to announce us') + + return model_alias_dict diff --git a/mmpose/apis/webcam/__init__.py b/mmpose/apis/webcam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..271b238c674479f7120e9f8ef0b1087cb06b42e4 --- /dev/null +++ b/mmpose/apis/webcam/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .webcam_executor import WebcamExecutor + +__all__ = ['WebcamExecutor'] diff --git a/mmpose/apis/webcam/nodes/__init__.py b/mmpose/apis/webcam/nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50f7c899d3d5ee2fb06d62cd9fb3963fffbd99de --- /dev/null +++ b/mmpose/apis/webcam/nodes/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_visualizer_node import BaseVisualizerNode +from .helper_nodes import MonitorNode, ObjectAssignerNode, RecorderNode +from .model_nodes import DetectorNode, TopdownPoseEstimatorNode +from .node import Node +from .registry import NODES +from .visualizer_nodes import (BigeyeEffectNode, NoticeBoardNode, + ObjectVisualizerNode, SunglassesEffectNode) + +__all__ = [ + 'BaseVisualizerNode', 'NODES', 'MonitorNode', 'ObjectAssignerNode', + 'RecorderNode', 'DetectorNode', 'TopdownPoseEstimatorNode', 'Node', + 'BigeyeEffectNode', 'NoticeBoardNode', 'ObjectVisualizerNode', + 'ObjectAssignerNode', 'SunglassesEffectNode' +] diff --git a/mmpose/apis/webcam/nodes/base_visualizer_node.py b/mmpose/apis/webcam/nodes/base_visualizer_node.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0ba397d4bf31fc44df86714ce828344d1fdb76 --- /dev/null +++ b/mmpose/apis/webcam/nodes/base_visualizer_node.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Dict, List, Optional, Union + +import numpy as np + +from ..utils import FrameMessage, Message +from .node import Node + + +class BaseVisualizerNode(Node): + """Base class for nodes whose function is to create visual effects, like + visualizing model predictions, showing graphics or showing text messages. + + All subclass should implement the method ``draw()``. + + Args: + name (str): The node name (also thread name) + input_buffer (str): The name of the input buffer + output_buffer (str | list): The name(s) of the output buffer(s). + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: (1) If ``enable_key`` is set, + the ``bypass()`` method need to be overridden to define the node + behavior when disabled; (2) Some hot-keys are reserved for + particular use. For example: 'q', 'Q' and 27 are used for exiting. + Default: ``None`` + enable (bool): Default enable/disable status. Default: ``True`` + """ + + def __init__(self, + name: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = True): + + super().__init__(name=name, enable_key=enable_key, enable=enable) + + # Register buffers + self.register_input_buffer(input_buffer, 'input', trigger=True) + self.register_output_buffer(output_buffer) + + def process(self, input_msgs: Dict[str, Message]) -> Union[Message, None]: + input_msg = input_msgs['input'] + + img = self.draw(input_msg) + input_msg.set_image(img) + + return input_msg + + def bypass(self, input_msgs: Dict[str, Message]) -> Union[Message, None]: + return input_msgs['input'] + + @abstractmethod + def draw(self, input_msg: FrameMessage) -> np.ndarray: + """Draw on the frame image of the input FrameMessage. + + Args: + input_msg (:obj:`FrameMessage`): The message of the frame to draw + on + + Returns: + np.array: The processed image. + """ diff --git a/mmpose/apis/webcam/nodes/helper_nodes/__init__.py b/mmpose/apis/webcam/nodes/helper_nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb0ed9dd156bfaeb0f77c78b4afa82d877064cf --- /dev/null +++ b/mmpose/apis/webcam/nodes/helper_nodes/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .monitor_node import MonitorNode +from .object_assigner_node import ObjectAssignerNode +from .recorder_node import RecorderNode + +__all__ = ['MonitorNode', 'ObjectAssignerNode', 'RecorderNode'] diff --git a/mmpose/apis/webcam/nodes/helper_nodes/monitor_node.py b/mmpose/apis/webcam/nodes/helper_nodes/monitor_node.py new file mode 100644 index 0000000000000000000000000000000000000000..305490dc52efd314e8319084b9e806e4a3edb01c --- /dev/null +++ b/mmpose/apis/webcam/nodes/helper_nodes/monitor_node.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import cv2 +import numpy as np +from mmcv import color_val + +from ..node import Node +from ..registry import NODES + +try: + import psutil + psutil_proc = psutil.Process() +except (ImportError, ModuleNotFoundError): + psutil_proc = None + + +@NODES.register_module() +class MonitorNode(Node): + """Show diagnostic information. + + Args: + name (str): The node name (also thread name) + input_buffer (str): The name of the input buffer + output_buffer (str|list): The name(s) of the output buffer(s) + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: (1) If ``enable_key`` is set, + the ``bypass()`` method need to be overridden to define the node + behavior when disabled; (2) Some hot-keys are reserved for + particular use. For example: 'q', 'Q' and 27 are used for exiting. + Default: ``None`` + enable (bool): Default enable/disable status. Default: ``True`` + x_offset (int): The position of the text box's left border in + pixels. Default: 20 + y_offset (int): The position of the text box's top border in + pixels. Default: 20 + y_delta (int): The line height in pixels. Default: 15 + text_color (str|tuple): The font color represented in a color name or + a BGR tuple. Default: ``'black'`` + backbround_color (str|tuple): The background color represented in a + color name or a BGR tuple. Default: (255, 183, 0) + text_scale (float): The font scale factor that is multiplied by the + base size. Default: 0.4 + ignore_items (list[str], optional): Specify the node information items + that will not be shown. See ``MonitorNode._default_ignore_items`` + for the default setting. + + Example:: + >>> cfg = dict( + ... type='MonitorNode', + ... name='monitor', + ... enable_key='m', + ... enable=False, + ... input_buffer='vis_notice', + ... output_buffer='display') + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + _default_ignore_items = ['timestamp'] + + def __init__(self, + name: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = False, + x_offset=20, + y_offset=20, + y_delta=15, + text_color='black', + background_color=(255, 183, 0), + text_scale=0.4, + ignore_items: Optional[List[str]] = None): + super().__init__(name=name, enable_key=enable_key, enable=enable) + + self.x_offset = x_offset + self.y_offset = y_offset + self.y_delta = y_delta + self.text_color = color_val(text_color) + self.background_color = color_val(background_color) + self.text_scale = text_scale + if ignore_items is None: + self.ignore_items = self._default_ignore_items + else: + self.ignore_items = ignore_items + + self.register_input_buffer(input_buffer, 'input', trigger=True) + self.register_output_buffer(output_buffer) + + def process(self, input_msgs): + input_msg = input_msgs['input'] + + input_msg.update_route_info( + node_name='System Info', + node_type='none', + info=self._get_system_info()) + + img = input_msg.get_image() + route_info = input_msg.get_route_info() + img = self._show_route_info(img, route_info) + + input_msg.set_image(img) + return input_msg + + def _get_system_info(self): + """Get the system information including CPU and memory usage. + + Returns: + dict: The system information items. + """ + sys_info = {} + if psutil_proc is not None: + sys_info['CPU(%)'] = psutil_proc.cpu_percent() + sys_info['Memory(%)'] = psutil_proc.memory_percent() + return sys_info + + def _show_route_info(self, img: np.ndarray, + route_info: List[Dict]) -> np.ndarray: + """Show the route information in the frame. + + Args: + img (np.ndarray): The frame image. + route_info (list[dict]): The route information of the frame. + + Returns: + np.ndarray: The processed image. + """ + canvas = np.full(img.shape, self.background_color, dtype=img.dtype) + + x = self.x_offset + y = self.y_offset + + max_len = 0 + + def _put_line(line=''): + nonlocal y, max_len + cv2.putText(canvas, line, (x, y), cv2.FONT_HERSHEY_DUPLEX, + self.text_scale, self.text_color, 1) + y += self.y_delta + max_len = max(max_len, len(line)) + + for node_info in route_info: + title = f'{node_info["node"]}({node_info["node_type"]})' + _put_line(title) + for k, v in node_info['info'].items(): + if k in self.ignore_items: + continue + if isinstance(v, float): + v = f'{v:.1f}' + _put_line(f' {k}: {v}') + + x1 = max(0, self.x_offset) + x2 = min(img.shape[1], int(x + max_len * self.text_scale * 20)) + y1 = max(0, self.y_offset - self.y_delta) + y2 = min(img.shape[0], y) + + src1 = canvas[y1:y2, x1:x2] + src2 = img[y1:y2, x1:x2] + img[y1:y2, x1:x2] = cv2.addWeighted(src1, 0.5, src2, 0.5, 0) + + return img + + def bypass(self, input_msgs): + return input_msgs['input'] diff --git a/mmpose/apis/webcam/nodes/helper_nodes/object_assigner_node.py b/mmpose/apis/webcam/nodes/helper_nodes/object_assigner_node.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a7804ab426fd27f29055aa3d15e0ffdfa359cd --- /dev/null +++ b/mmpose/apis/webcam/nodes/helper_nodes/object_assigner_node.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +from typing import List, Union + +from mmpose.utils.timer import RunningAverage +from ..node import Node +from ..registry import NODES + + +@NODES.register_module() +class ObjectAssignerNode(Node): + """Assign the object information to the frame message. + + :class:`ObjectAssignerNode` enables asynchronous processing of model + inference and video I/O, so the video will be captured and displayed + smoothly regardless of the model inference speed. Specifically, + :class:`ObjectAssignerNode` takes messages from both model branch and + video I/O branch as its input, indicated as "object message" and "frame + message" respectively. When an object message arrives it will update the + latest object information; and when a frame message arrives, it will be + assigned with the latest object information and output. + + Specially, if the webcam executor is set to synchrounous mode, the + behavior of :class:`ObjectAssignerNode` will be different: When an object + message arrives, it will trigger an output of itself; and the frame + messages will be ignored. + + Args: + name (str): The node name (also thread name) + frame_buffer (str): Buffer name for frame messages + object_buffer (str): Buffer name for object messages + output_buffer (str): The name(s) of the output buffer(s) + + Example:: + >>> cfg =dict( + ... type='ObjectAssignerNode', + ... name='object assigner', + ... frame_buffer='_frame_', + ... # `_frame_` is an executor-reserved buffer + ... object_buffer='animal_pose', + ... output_buffer='frame') + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + def __init__(self, name: str, frame_buffer: str, object_buffer: str, + output_buffer: Union[str, List[str]]): + super().__init__(name=name, enable=True) + self.synchronous = None + + # Cache the latest model result + self.last_object_msg = None + self.last_output_msg = None + + # Inference speed analysis + self.frame_fps = RunningAverage(window=10) + self.frame_lag = RunningAverage(window=10) + self.object_fps = RunningAverage(window=10) + self.object_lag = RunningAverage(window=10) + + # Register buffers + # The trigger buffer depends on the executor.synchronous attribute, + # so it will be set later after the executor is assigned in + # ``set_executor``. + self.register_input_buffer(object_buffer, 'object', trigger=False) + self.register_input_buffer(frame_buffer, 'frame', trigger=False) + self.register_output_buffer(output_buffer) + + def set_executor(self, executor): + super().set_executor(executor) + # Set synchronous according to the executor + if executor.synchronous: + self.synchronous = True + trigger = 'object' + else: + self.synchronous = False + trigger = 'frame' + + # Set trigger input buffer according to the synchronous setting + for buffer_info in self._input_buffers: + if buffer_info.input_name == trigger: + buffer_info.trigger = True + + def process(self, input_msgs): + object_msg = input_msgs['object'] + + # Update last result + if object_msg is not None: + # Update result FPS + if self.last_object_msg is not None: + self.object_fps.update( + 1.0 / + (object_msg.timestamp - self.last_object_msg.timestamp)) + # Update inference latency + self.object_lag.update(time.time() - object_msg.timestamp) + # Update last inference result + self.last_object_msg = object_msg + + if not self.synchronous: + # Asynchronous mode: + # Assign the latest object information to the + # current frame. + frame_msg = input_msgs['frame'] + + self.frame_lag.update(time.time() - frame_msg.timestamp) + + # Assign objects to frame + if self.last_object_msg is not None: + frame_msg.update_objects(self.last_object_msg.get_objects()) + frame_msg.merge_route_info( + self.last_object_msg.get_route_info()) + + output_msg = frame_msg + + else: + # Synchronous mode: + # The current frame will be ignored. Instead, + # the frame from which the latest object information is obtained + # will be used. + self.frame_lag.update(time.time() - object_msg.timestamp) + output_msg = object_msg + + # Update frame fps and lag + if self.last_output_msg is not None: + self.frame_lag.update(time.time() - output_msg.timestamp) + self.frame_fps.update( + 1.0 / (output_msg.timestamp - self.last_output_msg.timestamp)) + self.last_output_msg = output_msg + + return output_msg + + def _get_node_info(self): + info = super()._get_node_info() + info['object_fps'] = self.object_fps.average() + info['object_lag (ms)'] = self.object_lag.average() * 1000 + info['frame_fps'] = self.frame_fps.average() + info['frame_lag (ms)'] = self.frame_lag.average() * 1000 + return info diff --git a/mmpose/apis/webcam/nodes/helper_nodes/recorder_node.py b/mmpose/apis/webcam/nodes/helper_nodes/recorder_node.py new file mode 100644 index 0000000000000000000000000000000000000000..b35a77869248817d34ed6ee499d89dfbc0948c80 --- /dev/null +++ b/mmpose/apis/webcam/nodes/helper_nodes/recorder_node.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from queue import Full, Queue +from threading import Thread +from typing import List, Union + +import cv2 + +from ..node import Node +from ..registry import NODES + + +@NODES.register_module() +class RecorderNode(Node): + """Record the video frames into a local file. + + :class:`RecorderNode` uses OpenCV backend to record the video. Recording + is performed in a separate thread to avoid blocking the data stream. A + buffer queue is used to cached the arrived frame images. + + Args: + name (str): The node name (also thread name) + input_buffer (str): The name of the input buffer + output_buffer (str|list): The name(s) of the output buffer(s) + out_video_file (str): The path of the output video file + out_video_fps (int): The frame rate of the output video. Default: 30 + out_video_codec (str): The codec of the output video. Default: 'mp4v' + buffer_size (int): Size of the buffer queue that caches the arrived + frame images. + enable (bool): Default enable/disable status. Default: ``True``. + + Example:: + >>> cfg = dict( + ... type='RecorderNode', + ... name='recorder', + ... out_video_file='webcam_demo.mp4', + ... input_buffer='display', + ... output_buffer='_display_' + ... # `_display_` is an executor-reserved buffer + ... ) + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + def __init__( + self, + name: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + out_video_file: str, + out_video_fps: int = 30, + out_video_codec: str = 'mp4v', + buffer_size: int = 30, + enable: bool = True, + ): + super().__init__(name=name, enable_key=None, enable=enable) + + self.queue = Queue(maxsize=buffer_size) + self.out_video_file = out_video_file + self.out_video_fps = out_video_fps + self.out_video_codec = out_video_codec + self.vwriter = None + + # Register buffers + self.register_input_buffer(input_buffer, 'input', trigger=True) + self.register_output_buffer(output_buffer) + + # Start a new thread to write frame + self.t_record = Thread(target=self._record, args=(), daemon=True) + self.t_record.start() + + def process(self, input_msgs): + + input_msg = input_msgs['input'] + img = input_msg.get_image() if input_msg is not None else None + img_queued = False + + while not img_queued: + try: + self.queue.put(img, timeout=1) + img_queued = True + self.logger.info('Recorder received one frame.') + except Full: + self.logger.warn('Recorder jamed!') + + return input_msg + + def _record(self): + """This method is used to create a thread to get frame images from the + buffer queue and write them into the file.""" + + while True: + + img = self.queue.get() + + if img is None: + break + + if self.vwriter is None: + fourcc = cv2.VideoWriter_fourcc(*self.out_video_codec) + fps = self.out_video_fps + frame_size = (img.shape[1], img.shape[0]) + self.vwriter = cv2.VideoWriter(self.out_video_file, fourcc, + fps, frame_size) + assert self.vwriter.isOpened() + + self.vwriter.write(img) + + self.logger.info('Recorder released.') + if self.vwriter is not None: + self.vwriter.release() + + def on_exit(self): + try: + # Try putting a None into the output queue so the self.vwriter will + # be released after all queue frames have been written to file. + self.queue.put(None, timeout=1) + self.t_record.join(timeout=1) + except Full: + pass + + if self.t_record.is_alive(): + # Force to release self.vwriter + self.logger.warn('Recorder forced release!') + if self.vwriter is not None: + self.vwriter.release() diff --git a/mmpose/apis/webcam/nodes/model_nodes/__init__.py b/mmpose/apis/webcam/nodes/model_nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a116bfec8d7011edeca54ac8a867c6fc0e05c3 --- /dev/null +++ b/mmpose/apis/webcam/nodes/model_nodes/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .detector_node import DetectorNode +from .pose_estimator_node import TopdownPoseEstimatorNode + +__all__ = ['DetectorNode', 'TopdownPoseEstimatorNode'] diff --git a/mmpose/apis/webcam/nodes/model_nodes/detector_node.py b/mmpose/apis/webcam/nodes/model_nodes/detector_node.py new file mode 100644 index 0000000000000000000000000000000000000000..350831fe62a0a53b363bb24ab37ac9b0fe33b260 --- /dev/null +++ b/mmpose/apis/webcam/nodes/model_nodes/detector_node.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import numpy as np + +from mmpose.utils import adapt_mmdet_pipeline +from ...utils import get_config_path +from ..node import Node +from ..registry import NODES + +try: + from mmdet.apis import inference_detector, init_detector + has_mmdet = True +except (ImportError, ModuleNotFoundError): + has_mmdet = False + + +@NODES.register_module() +class DetectorNode(Node): + """Detect objects from the frame image using MMDetection model. + + Note that MMDetection is required for this node. Please refer to + `MMDetection documentation `_ for the installation guide. + + Parameters: + name (str): The node name (also thread name) + model_cfg (str): The model config file + model_checkpoint (str): The model checkpoint file + input_buffer (str): The name of the input buffer + output_buffer (str|list): The name(s) of the output buffer(s) + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: (1) If ``enable_key`` is set, + the ``bypass()`` method need to be overridden to define the node + behavior when disabled; (2) Some hot-keys are reserved for + particular use. For example: 'q', 'Q' and 27 are used for exiting. + Default: ``None`` + enable (bool): Default enable/disable status. Default: ``True`` + device (str): Specify the device to hold model weights and inference + the model. Default: ``'cuda:0'`` + bbox_thr (float): Set a threshold to filter out objects with low bbox + scores. Default: 0.5 + multi_input (bool): Whether load all frames in input buffer. If True, + all frames in buffer will be loaded and stacked. The latest frame + is used to detect objects of interest. Default: False + + Example:: + >>> cfg = dict( + ... type='DetectorNode', + ... name='detector', + ... model_config='demo/mmdetection_cfg/' + ... 'ssdlite_mobilenetv2_scratch_600e_coco.py', + ... model_checkpoint='https://download.openmmlab.com' + ... '/mmdetection/v2.0/ssd/' + ... 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_' + ... 'scratch_600e_coco_20210629_110627-974d9307.pth', + ... # `_input_` is an executor-reserved buffer + ... input_buffer='_input_', + ... output_buffer='det_result') + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + def __init__(self, + name: str, + model_config: str, + model_checkpoint: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = True, + device: str = 'cuda:0', + bbox_thr: float = 0.5, + multi_input: bool = False): + # Check mmdetection is installed + assert has_mmdet, \ + f'MMDetection is required for {self.__class__.__name__}.' + + super().__init__( + name=name, + enable_key=enable_key, + enable=enable, + multi_input=multi_input) + + self.model_config = get_config_path(model_config, 'mmdet') + self.model_checkpoint = model_checkpoint + self.device = device.lower() + self.bbox_thr = bbox_thr + + # Init model + self.model = init_detector( + self.model_config, self.model_checkpoint, device=self.device) + self.model.cfg = adapt_mmdet_pipeline(self.model.cfg) + + # Register buffers + self.register_input_buffer(input_buffer, 'input', trigger=True) + self.register_output_buffer(output_buffer) + + def bypass(self, input_msgs): + return input_msgs['input'] + + def process(self, input_msgs): + input_msg = input_msgs['input'] + + if self.multi_input: + imgs = [frame.get_image() for frame in input_msg] + input_msg = input_msg[-1] + + img = input_msg.get_image() + + preds = inference_detector(self.model, img) + objects = self._post_process(preds) + input_msg.update_objects(objects) + + if self.multi_input: + input_msg.set_image(np.stack(imgs, axis=0)) + + return input_msg + + def _post_process(self, preds) -> List[Dict]: + """Post-process the predictions of MMDetection model.""" + instances = preds.pred_instances.cpu().numpy() + + classes = self.model.dataset_meta['classes'] + if isinstance(classes, str): + classes = (classes, ) + + objects = [] + for i in range(len(instances)): + if instances.scores[i] < self.bbox_thr: + continue + class_id = instances.labels[i] + obj = { + 'class_id': class_id, + 'label': classes[class_id], + 'bbox': instances.bboxes[i], + 'det_model_cfg': self.model.cfg, + 'dataset_meta': self.model.dataset_meta.copy(), + } + objects.append(obj) + return objects diff --git a/mmpose/apis/webcam/nodes/model_nodes/pose_estimator_node.py b/mmpose/apis/webcam/nodes/model_nodes/pose_estimator_node.py new file mode 100644 index 0000000000000000000000000000000000000000..64691cf5606d950ed18053740d1f6aca822a8386 --- /dev/null +++ b/mmpose/apis/webcam/nodes/model_nodes/pose_estimator_node.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np + +from mmpose.apis import inference_topdown, init_model +from ...utils import get_config_path +from ..node import Node +from ..registry import NODES + + +@dataclass +class TrackInfo: + """Dataclass for object tracking information.""" + next_id: int = 0 + last_objects: List = None + + +@NODES.register_module() +class TopdownPoseEstimatorNode(Node): + """Perform top-down pose estimation using MMPose model. + + The node should be placed after an object detection node. + + Parameters: + name (str): The node name (also thread name) + model_cfg (str): The model config file + model_checkpoint (str): The model checkpoint file + input_buffer (str): The name of the input buffer + output_buffer (str|list): The name(s) of the output buffer(s) + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: (1) If ``enable_key`` is set, + the ``bypass()`` method need to be overridden to define the node + behavior when disabled; (2) Some hot-keys are reserved for + particular use. For example: 'q', 'Q' and 27 are used for exiting. + Default: ``None`` + enable (bool): Default enable/disable status. Default: ``True`` + device (str): Specify the device to hold model weights and inference + the model. Default: ``'cuda:0'`` + class_ids (list[int], optional): Specify the object category indices + to apply pose estimation. If both ``class_ids`` and ``labels`` + are given, ``labels`` will be ignored. If neither is given, pose + estimation will be applied for all objects. Default: ``None`` + labels (list[str], optional): Specify the object category names to + apply pose estimation. See also ``class_ids``. Default: ``None`` + bbox_thr (float): Set a threshold to filter out objects with low bbox + scores. Default: 0.5 + + Example:: + >>> cfg = dict( + ... type='TopdownPoseEstimatorNode', + ... name='human pose estimator', + ... model_config='configs/wholebody/2d_kpt_sview_rgb_img/' + ... 'topdown_heatmap/coco-wholebody/' + ... 'vipnas_mbv3_coco_wholebody_256x192_dark.py', + ... model_checkpoint='https://download.openmmlab.com/mmpose/' + ... 'top_down/vipnas/vipnas_mbv3_coco_wholebody_256x192_dark' + ... '-e2158108_20211205.pth', + ... labels=['person'], + ... input_buffer='det_result', + ... output_buffer='human_pose') + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + def __init__(self, + name: str, + model_config: str, + model_checkpoint: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = True, + device: str = 'cuda:0', + class_ids: Optional[List[int]] = None, + labels: Optional[List[str]] = None, + bbox_thr: float = 0.5): + super().__init__(name=name, enable_key=enable_key, enable=enable) + + # Init model + self.model_config = get_config_path(model_config, 'mmpose') + self.model_checkpoint = model_checkpoint + self.device = device.lower() + + self.class_ids = class_ids + self.labels = labels + self.bbox_thr = bbox_thr + + # Init model + self.model = init_model( + self.model_config, self.model_checkpoint, device=self.device) + + # Register buffers + self.register_input_buffer(input_buffer, 'input', trigger=True) + self.register_output_buffer(output_buffer) + + def bypass(self, input_msgs): + return input_msgs['input'] + + def process(self, input_msgs): + + input_msg = input_msgs['input'] + img = input_msg.get_image() + + if self.class_ids: + objects = input_msg.get_objects( + lambda x: x.get('class_id') in self.class_ids) + elif self.labels: + objects = input_msg.get_objects( + lambda x: x.get('label') in self.labels) + else: + objects = input_msg.get_objects() + + if len(objects) > 0: + # Inference pose + bboxes = np.stack([object['bbox'] for object in objects]) + pose_results = inference_topdown(self.model, img, bboxes) + + # Update objects + for pose_result, object in zip(pose_results, objects): + pred_instances = pose_result.pred_instances + object['keypoints'] = pred_instances.keypoints[0] + object['keypoint_scores'] = pred_instances.keypoint_scores[0] + + dataset_meta = self.model.dataset_meta.copy() + dataset_meta.update(object.get('dataset_meta', dict())) + object['dataset_meta'] = dataset_meta + object['pose_model_cfg'] = self.model.cfg + + input_msg.update_objects(objects) + + return input_msg diff --git a/mmpose/apis/webcam/nodes/node.py b/mmpose/apis/webcam/nodes/node.py new file mode 100644 index 0000000000000000000000000000000000000000..3d34ae1cc081f7a71f5f62397b40ceee7b894503 --- /dev/null +++ b/mmpose/apis/webcam/nodes/node.py @@ -0,0 +1,407 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +import time +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from threading import Thread +from typing import Callable, Dict, List, Optional, Tuple, Union + +from mmengine import is_method_overridden + +from mmpose.utils import StopWatch +from ..utils import Message, VideoEndingMessage, limit_max_fps + + +@dataclass +class BufferInfo(): + """Dataclass for buffer information.""" + buffer_name: str + input_name: Optional[str] = None + trigger: bool = False + + +@dataclass +class EventInfo(): + """Dataclass for event handler information.""" + event_name: str + is_keyboard: bool = False + handler_func: Optional[Callable] = None + + +class Node(Thread, metaclass=ABCMeta): + """Base class for node, which is the interface of basic function module. + + :class:`Node` inherits :class:`threading.Thread`. All subclasses should + override following methods: + + - ``process()`` + - ``bypass()`` (optional) + + + Parameters: + name (str): The node name (also thread name) + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: (1) If ``enable_key`` is set, + the ``bypass()`` method need to be overridden to define the node + behavior when disabled; (2) Some hot-keys are reserved for + particular use. For example: 'q', 'Q' and 27 are used for exiting. + Default: ``None`` + max_fps (int): Maximum FPS of the node. This is to avoid the node + running unrestrictedly and causing large resource consuming. + Default: 30 + input_check_interval (float): Minimum interval (in millisecond) between + checking if input is ready. Default: 0.001 + enable (bool): Default enable/disable status. Default: ``True`` + daemon (bool): Whether node is a daemon. Default: ``True`` + multi_input (bool): Whether load all messages in buffer. If False, + only one message will be loaded each time. Default: ``False`` + """ + + def __init__(self, + name: str, + enable_key: Optional[Union[str, int]] = None, + max_fps: int = 30, + input_check_interval: float = 0.01, + enable: bool = True, + daemon: bool = False, + multi_input: bool = False): + super().__init__(name=name, daemon=daemon) + self._executor = None + self._enabled = enable + self.enable_key = enable_key + self.max_fps = max_fps + self.input_check_interval = input_check_interval + self.multi_input = multi_input + + # A partitioned buffer manager the executor's buffer manager that + # only accesses the buffers related to the node + self._buffer_manager = None + + # Input/output buffers are a list of registered buffers' information + self._input_buffers = [] + self._output_buffers = [] + + # Event manager is a copy of assigned executor's event manager + self._event_manager = None + + # A list of registered event information + # See register_event() for more information + # Note that we recommend to handle events in nodes by registering + # handlers, but one can still access the raw event by _event_manager + self._registered_events = [] + + # A list of (listener_threads, event_info) + # See set_executor() for more information + self._event_listener_threads = [] + + # A timer to calculate node FPS + self._timer = StopWatch(window=10) + + # Register enable toggle key + if self.enable_key: + # If the node allows toggling enable, it should override the + # `bypass` method to define the node behavior when disabled. + if not is_method_overridden('bypass', Node, self.__class__): + raise NotImplementedError( + f'The node {self.__class__} does not support toggling' + 'enable but got argument `enable_key`. To support toggling' + 'enable, please override the `bypass` method of the node.') + + self.register_event( + event_name=self.enable_key, + is_keyboard=True, + handler_func=self._toggle_enable, + ) + + # Logger + self.logger = logging.getLogger(f'Node "{self.name}"') + + @property + def registered_buffers(self): + return self._input_buffers + self._output_buffers + + @property + def registered_events(self): + return self._registered_events.copy() + + def _toggle_enable(self): + self._enabled = not self._enabled + + def register_input_buffer(self, + buffer_name: str, + input_name: str, + trigger: bool = False): + """Register an input buffer, so that Node can automatically check if + data is ready, fetch data from the buffers and format the inputs to + feed into `process` method. + + The subclass of Node should invoke `register_input_buffer` in its + `__init__` method. This method can be invoked multiple times to + register multiple input buffers. + + Args: + buffer_name (str): The name of the buffer + input_name (str): The name of the fetched message from the + corresponding buffer + trigger (bool): An trigger input means the node will wait + until the input is ready before processing. Otherwise, an + inessential input will not block the processing, instead + a None will be fetched if the buffer is not ready. + """ + buffer_info = BufferInfo(buffer_name, input_name, trigger) + self._input_buffers.append(buffer_info) + + def register_output_buffer(self, buffer_name: Union[str, List[str]]): + """Register one or multiple output buffers, so that the Node can + automatically send the output of the `process` method to these buffers. + + The subclass of Node should invoke `register_output_buffer` in its + `__init__` method. + + Args: + buffer_name (str|list): The name(s) of the output buffer(s). + """ + + if not isinstance(buffer_name, list): + buffer_name = [buffer_name] + + for name in buffer_name: + buffer_info = BufferInfo(name) + self._output_buffers.append(buffer_info) + + def register_event(self, + event_name: str, + is_keyboard: bool = False, + handler_func: Optional[Callable] = None): + """Register an event. All events used in the node need to be registered + in __init__(). If a callable handler is given, a thread will be create + to listen and handle the event when the node starts. + + Args: + Args: + event_name (str|int): The event name. If is_keyboard==True, + event_name should be a str (as char) or an int (as ascii) + is_keyboard (bool): Indicate whether it is an keyboard + event. If True, the argument event_name will be regarded as a + key indicator. + handler_func (callable, optional): The event handler function, + which should be a collable object with no arguments or + return values. Default: ``None``. + """ + event_info = EventInfo(event_name, is_keyboard, handler_func) + self._registered_events.append(event_info) + + def set_executor(self, executor): + """Assign the node to an executor so the node can access the buffers + and event manager of the executor. + + This method should be invoked by the executor instance. + + Args: + executor (:obj:`WebcamExecutor`): The executor to hold the node + """ + # Get partitioned buffer manager + buffer_names = [ + buffer.buffer_name + for buffer in self._input_buffers + self._output_buffers + ] + self._buffer_manager = executor.buffer_manager.get_sub_manager( + buffer_names) + + # Get event manager + self._event_manager = executor.event_manager + + def _get_input_from_buffer(self) -> Tuple[bool, Optional[Dict]]: + """Get and pack input data. + + The function returns a tuple (status, data). If the trigger buffers + are ready, the status flag will be True, and the packed data is a dict + whose items are buffer names and corresponding messages (unready + non-trigger buffers will give a `None`). Otherwise, the status flag is + False and the packed data is None. + + Returns: + tuple[bool, dict]: The first item is a bool value indicating + whether input is ready (i.e., all tirgger buffers are ready). The + second value is a dict of buffer names and messages. + """ + buffer_manager = self._buffer_manager + + if buffer_manager is None: + raise ValueError(f'Node "{self.name}": not set to an executor.') + + # Check that trigger buffers are ready + for buffer_info in self._input_buffers: + if buffer_info.trigger and buffer_manager.is_empty( + buffer_info.buffer_name): + return False, None + + # Default input + result = { + buffer_info.input_name: None + for buffer_info in self._input_buffers + } + + for buffer_info in self._input_buffers: + + while not buffer_manager.is_empty(buffer_info.buffer_name): + msg = buffer_manager.get(buffer_info.buffer_name, block=False) + if self.multi_input: + if result[buffer_info.input_name] is None: + result[buffer_info.input_name] = [] + result[buffer_info.input_name].append(msg) + else: + result[buffer_info.input_name] = msg + break + + # Return unsuccessful flag if any trigger input is unready + if buffer_info.trigger and result[buffer_info.input_name] is None: + return False, None + + return True, result + + def _send_output_to_buffers(self, output_msg): + """Send output of ``process()`` to the registered output buffers. + + Args: + output_msg (Message): output message + """ + for buffer_info in self._output_buffers: + buffer_name = buffer_info.buffer_name + self._buffer_manager.put_force(buffer_name, output_msg) + + @abstractmethod + def process(self, input_msgs: Dict[str, Message]) -> Union[Message, None]: + """The method that implements the function of the node. + + This method will be invoked when the node is enabled and the input + data is ready. All subclasses of Node should override this method. + + Args: + input_msgs (dict[str, :obj:`Message`]): The input data collected + from the buffers. For each item, the key is the `input_name` + of the registered input buffer, and the value is a Message + instance fetched from the buffer (or None if the buffer is + non-trigger and not ready). + + Returns: + Message: The output message of the node which will be send to all + registered output buffers. + """ + + def bypass(self, input_msgs: Dict[str, Message]) -> Union[Message, None]: + """The method that defines the node behavior when disabled. + + Note that a node must override this method if it has `enable_key`. + This method has the same signature as ``process()``. + + Args: + input_msgs (dict[str, :obj:`Message`]): The input data collected + from the buffers. For each item, the key is the `input_name` + of the registered input buffer, and the value is a Message + instance fetched from the buffer (or None if the buffer is + non-trigger and not ready). + + Returns: + Message: The output message of the node which will be send to all + registered output buffers. + """ + raise NotImplementedError + + def _get_node_info(self) -> Dict: + """Get route information of the node. + + Default information includes: + - ``'fps'``: The processing speed of the node + - ``'timestamp'``: The time that this method is invoked + + Subclasses can override this method to customize the node information. + + Returns: + dict: The items of node information + """ + info = {'fps': self._timer.report('_FPS_'), 'timestamp': time.time()} + return info + + def on_exit(self): + """This method will be invoked on event `_exit_`. + + Subclasses should override this method to specifying the exiting + behavior. + """ + + def run(self): + """Method representing the Node's activity. + + This method override the standard ``run()`` method of Thread. + Subclasses of :class:`Node` should not override this method in + subclasses. + """ + + self.logger.info('Process starts.') + + # Create event listener threads + for event_info in self._registered_events: + + if event_info.handler_func is None: + continue + + def event_listener(): + while True: + with self._event_manager.wait_and_handle( + event_info.event_name, event_info.is_keyboard): + event_info.handler_func() + + t_listener = Thread(target=event_listener, args=(), daemon=True) + t_listener.start() + self._event_listener_threads.append(t_listener) + + # Loop + while True: + # Exit + if self._event_manager.is_set('_exit_'): + self.on_exit() + break + + # Check if input is ready + input_status, input_msgs = self._get_input_from_buffer() + + # Input is not ready + if not input_status: + time.sleep(self.input_check_interval) + continue + + # If a VideoEndingMessage is received, broadcast the signal + # without invoking process() or bypass() + video_ending = False + for _, msg in input_msgs.items(): + if isinstance(msg, VideoEndingMessage): + self._send_output_to_buffers(msg) + video_ending = True + break + + if video_ending: + self.on_exit() + break + + # Check if enabled + if not self._enabled: + # Override bypass method to define node behavior when disabled + output_msg = self.bypass(input_msgs) + else: + with self._timer.timeit(): + with limit_max_fps(self.max_fps): + # Process + output_msg = self.process(input_msgs) + + if output_msg: + # Update route information + node_info = self._get_node_info() + output_msg.update_route_info(node=self, info=node_info) + + # Send output message + if output_msg is not None: + self._send_output_to_buffers(output_msg) + + self.logger.info('Process ends.') diff --git a/mmpose/apis/webcam/nodes/registry.py b/mmpose/apis/webcam/nodes/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..06d39fed63bb1972d5a59892c8d3d208f113e792 --- /dev/null +++ b/mmpose/apis/webcam/nodes/registry.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.registry import Registry + +NODES = Registry('node') diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/__init__.py b/mmpose/apis/webcam/nodes/visualizer_nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fad7e3037673fb2ef01fdcd709fb9e2212ecef5a --- /dev/null +++ b/mmpose/apis/webcam/nodes/visualizer_nodes/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bigeye_effect_node import BigeyeEffectNode +from .notice_board_node import NoticeBoardNode +from .object_visualizer_node import ObjectVisualizerNode +from .sunglasses_effect_node import SunglassesEffectNode + +__all__ = [ + 'ObjectVisualizerNode', 'NoticeBoardNode', 'SunglassesEffectNode', + 'BigeyeEffectNode' +] diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/bigeye_effect_node.py b/mmpose/apis/webcam/nodes/visualizer_nodes/bigeye_effect_node.py new file mode 100644 index 0000000000000000000000000000000000000000..3bbec3d670b90528dd234ac57fb792e379e8a9f5 --- /dev/null +++ b/mmpose/apis/webcam/nodes/visualizer_nodes/bigeye_effect_node.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import groupby +from typing import Dict, List, Optional, Union + +import cv2 +import numpy as np + +from ...utils import get_eye_keypoint_ids +from ..base_visualizer_node import BaseVisualizerNode +from ..registry import NODES + + +@NODES.register_module() +class BigeyeEffectNode(BaseVisualizerNode): + """Apply big-eye effect to the objects with eye keypoints in the frame. + + Args: + name (str): The node name (also thread name) + input_buffer (str): The name of the input buffer + output_buffer (str|list): The name(s) of the output buffer(s) + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: (1) If ``enable_key`` is set, + the ``bypass()`` method need to be overridden to define the node + behavior when disabled; (2) Some hot-keys are reserved for + particular use. For example: 'q', 'Q' and 27 are used for exiting. + Default: ``None`` + enable (bool): Default enable/disable status. Default: ``True`` + kpt_thr (float): The score threshold of valid keypoints. Default: 0.5 + + Example:: + >>> cfg = dict( + ... type='SunglassesEffectNode', + ... name='sunglasses', + ... enable_key='s', + ... enable=False, + ... input_buffer='vis', + ... output_buffer='vis_sunglasses') + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + def __init__(self, + name: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = True, + kpt_thr: float = 0.5): + + super().__init__( + name=name, + input_buffer=input_buffer, + output_buffer=output_buffer, + enable_key=enable_key, + enable=enable) + self.kpt_thr = kpt_thr + + def draw(self, input_msg): + canvas = input_msg.get_image() + + objects = input_msg.get_objects(lambda x: + ('keypoints' in x and 'bbox' in x)) + + for dataset_meta, group in groupby(objects, + lambda x: x['dataset_meta']): + left_eye_index, right_eye_index = get_eye_keypoint_ids( + dataset_meta) + canvas = self.apply_bigeye_effect(canvas, group, left_eye_index, + right_eye_index) + return canvas + + def apply_bigeye_effect(self, canvas: np.ndarray, objects: List[Dict], + left_eye_index: int, + right_eye_index: int) -> np.ndarray: + """Apply big-eye effect. + + Args: + canvas (np.ndarray): The image to apply the effect + objects (list[dict]): The object list with bbox and keypoints + - "bbox" ([K, 4(or 5)]): bbox in [x1, y1, x2, y2, (score)] + - "keypoints" ([K,3]): keypoints in [x, y, score] + left_eye_index (int): Keypoint index of left eye + right_eye_index (int): Keypoint index of right eye + + Returns: + np.ndarray: Processed image. + """ + + xx, yy = np.meshgrid( + np.arange(canvas.shape[1]), np.arange(canvas.shape[0])) + xx = xx.astype(np.float32) + yy = yy.astype(np.float32) + + for obj in objects: + bbox = obj['bbox'] + kpts = obj['keypoints'] + kpt_scores = obj['keypoint_scores'] + + if kpt_scores[left_eye_index] < self.kpt_thr or kpt_scores[ + right_eye_index] < self.kpt_thr: + continue + + kpt_leye = kpts[left_eye_index, :2] + kpt_reye = kpts[right_eye_index, :2] + for xc, yc in [kpt_leye, kpt_reye]: + + # distortion parameters + k1 = 0.001 + epe = 1e-5 + + scale = (bbox[2] - bbox[0])**2 + (bbox[3] - bbox[1])**2 + r2 = ((xx - xc)**2 + (yy - yc)**2) + r2 = (r2 + epe) / scale # normalized by bbox scale + + xx = (xx - xc) / (1 + k1 / r2) + xc + yy = (yy - yc) / (1 + k1 / r2) + yc + + canvas = cv2.remap( + canvas, + xx, + yy, + interpolation=cv2.INTER_AREA, + borderMode=cv2.BORDER_REPLICATE) + + return canvas diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/notice_board_node.py b/mmpose/apis/webcam/nodes/visualizer_nodes/notice_board_node.py new file mode 100644 index 0000000000000000000000000000000000000000..0578ec38eb399645eef7795577de29f4fb7240b9 --- /dev/null +++ b/mmpose/apis/webcam/nodes/visualizer_nodes/notice_board_node.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import cv2 +import numpy as np +from mmcv import color_val + +from ...utils import FrameMessage +from ..base_visualizer_node import BaseVisualizerNode +from ..registry import NODES + + +@NODES.register_module() +class NoticeBoardNode(BaseVisualizerNode): + """Show text messages in the frame. + + Args: + name (str): The node name (also thread name) + input_buffer (str): The name of the input buffer + output_buffer (str|list): The name(s) of the output buffer(s) + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: (1) If ``enable_key`` is set, + the ``bypass()`` method need to be overridden to define the node + behavior when disabled; (2) Some hot-keys are reserved for + particular use. For example: 'q', 'Q' and 27 are used for exiting. + Default: ``None`` + enable (bool): Default enable/disable status. Default: ``True`` + content_lines (list[str], optional): The lines of text message to show + in the frame. If not given, a default message will be shown. + Default: ``None`` + x_offset (int): The position of the notice board's left border in + pixels. Default: 20 + y_offset (int): The position of the notice board's top border in + pixels. Default: 20 + y_delta (int): The line height in pixels. Default: 15 + text_color (str|tuple): The font color represented in a color name or + a BGR tuple. Default: ``'black'`` + backbround_color (str|tuple): The background color represented in a + color name or a BGR tuple. Default: (255, 183, 0) + text_scale (float): The font scale factor that is multiplied by the + base size. Default: 0.4 + + Example:: + >>> cfg = dict( + ... type='NoticeBoardNode', + ... name='instruction', + ... enable_key='h', + ... enable=True, + ... input_buffer='vis_bigeye', + ... output_buffer='vis_notice', + ... content_lines=[ + ... 'This is a demo for pose visualization and simple image ' + ... 'effects. Have fun!', '', 'Hot-keys:', + ... '"v": Pose estimation result visualization', + ... '"s": Sunglasses effect B-)', '"b": Big-eye effect 0_0', + ... '"h": Show help information', + ... '"m": Show diagnostic information', '"q": Exit' + ... ], + ... ) + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + default_content_lines = ['This is a notice board!'] + + def __init__(self, + name: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = True, + content_lines: Optional[List[str]] = None, + x_offset: int = 20, + y_offset: int = 20, + y_delta: int = 15, + text_color: Union[str, Tuple[int, int, int]] = 'black', + background_color: Union[str, Tuple[int, int, + int]] = (255, 183, 0), + text_scale: float = 0.4): + super().__init__( + name=name, + input_buffer=input_buffer, + output_buffer=output_buffer, + enable_key=enable_key, + enable=enable) + + self.x_offset = x_offset + self.y_offset = y_offset + self.y_delta = y_delta + self.text_color = color_val(text_color) + self.background_color = color_val(background_color) + self.text_scale = text_scale + + if content_lines: + self.content_lines = content_lines + else: + self.content_lines = self.default_content_lines + + def draw(self, input_msg: FrameMessage) -> np.ndarray: + img = input_msg.get_image() + canvas = np.full(img.shape, self.background_color, dtype=img.dtype) + + x = self.x_offset + y = self.y_offset + + max_len = max([len(line) for line in self.content_lines]) + + def _put_line(line=''): + nonlocal y + cv2.putText(canvas, line, (x, y), cv2.FONT_HERSHEY_DUPLEX, + self.text_scale, self.text_color, 1) + y += self.y_delta + + for line in self.content_lines: + _put_line(line) + + x1 = max(0, self.x_offset) + x2 = min(img.shape[1], int(x + max_len * self.text_scale * 20)) + y1 = max(0, self.y_offset - self.y_delta) + y2 = min(img.shape[0], y) + + src1 = canvas[y1:y2, x1:x2] + src2 = img[y1:y2, x1:x2] + img[y1:y2, x1:x2] = cv2.addWeighted(src1, 0.5, src2, 0.5, 0) + + return img diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/object_visualizer_node.py b/mmpose/apis/webcam/nodes/visualizer_nodes/object_visualizer_node.py new file mode 100644 index 0000000000000000000000000000000000000000..ef28a0804cd86352700f9636bd97751b31383ce1 --- /dev/null +++ b/mmpose/apis/webcam/nodes/visualizer_nodes/object_visualizer_node.py @@ -0,0 +1,341 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from itertools import groupby +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import mmcv +import numpy as np + +from ...utils import FrameMessage +from ..base_visualizer_node import BaseVisualizerNode +from ..registry import NODES + + +def imshow_bboxes(img, + bboxes, + labels=None, + colors='green', + text_color='white', + thickness=1, + font_scale=0.5): + """Draw bboxes with labels (optional) on an image. This is a wrapper of + mmcv.imshow_bboxes. + + Args: + img (str or ndarray): The image to be displayed. + bboxes (ndarray): ndarray of shape (k, 4), each row is a bbox in + format [x1, y1, x2, y2]. + labels (str or list[str], optional): labels of each bbox. + colors (list[str or tuple or :obj:`Color`]): A list of colors. + text_color (str or tuple or :obj:`Color`): Color of texts. + thickness (int): Thickness of lines. + font_scale (float): Font scales of texts. + + Returns: + ndarray: The image with bboxes drawn on it. + """ + + # adapt to mmcv.imshow_bboxes input format + bboxes = np.split( + bboxes, bboxes.shape[0], axis=0) if bboxes.shape[0] > 0 else [] + if not isinstance(colors, list): + colors = [colors for _ in range(len(bboxes))] + colors = [mmcv.color_val(c) for c in colors] + assert len(bboxes) == len(colors) + + img = mmcv.imshow_bboxes( + img, + bboxes, + colors, + top_k=-1, + thickness=thickness, + show=False, + out_file=None) + + if labels is not None: + if not isinstance(labels, list): + labels = [labels for _ in range(len(bboxes))] + assert len(labels) == len(bboxes) + + for bbox, label, color in zip(bboxes, labels, colors): + if label is None: + continue + bbox_int = bbox[0, :4].astype(np.int32) + # roughly estimate the proper font size + text_size, text_baseline = cv2.getTextSize(label, + cv2.FONT_HERSHEY_DUPLEX, + font_scale, thickness) + text_x1 = bbox_int[0] + text_y1 = max(0, bbox_int[1] - text_size[1] - text_baseline) + text_x2 = bbox_int[0] + text_size[0] + text_y2 = text_y1 + text_size[1] + text_baseline + cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color, + cv2.FILLED) + cv2.putText(img, label, (text_x1, text_y2 - text_baseline), + cv2.FONT_HERSHEY_DUPLEX, font_scale, + mmcv.color_val(text_color), thickness) + + return img + + +def imshow_keypoints(img, + pose_result, + skeleton=None, + kpt_score_thr=0.3, + pose_kpt_color=None, + pose_link_color=None, + radius=4, + thickness=1, + show_keypoint_weight=False): + """Draw keypoints and links on an image. + + Args: + img (str or Tensor): The image to draw poses on. If an image array + is given, id will be modified in-place. + pose_result (list[kpts]): The poses to draw. Each element kpts is + a set of K keypoints as an Kx3 numpy.ndarray, where each + keypoint is represented as x, y, score. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, + the keypoint will not be drawn. + pose_link_color (np.array[Mx3]): Color of M links. If None, the + links will not be drawn. + thickness (int): Thickness of lines. + """ + + img = mmcv.imread(img) + img_h, img_w, _ = img.shape + + for kpts in pose_result: + + kpts = np.array(kpts, copy=False) + + # draw each point on image + if pose_kpt_color is not None: + assert len(pose_kpt_color) == len(kpts) + + for kid, kpt in enumerate(kpts): + x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] + + if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None: + # skip the point that should not be drawn + continue + + color = tuple(int(c) for c in pose_kpt_color[kid]) + if show_keypoint_weight: + img_copy = img.copy() + cv2.circle(img_copy, (int(x_coord), int(y_coord)), radius, + color, -1) + transparency = max(0, min(1, kpt_score)) + cv2.addWeighted( + img_copy, + transparency, + img, + 1 - transparency, + 0, + dst=img) + else: + cv2.circle(img, (int(x_coord), int(y_coord)), radius, + color, -1) + + # draw links + if skeleton is not None and pose_link_color is not None: + assert len(pose_link_color) == len(skeleton) + + for sk_id, sk in enumerate(skeleton): + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + + if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 + or pos1[1] >= img_h or pos2[0] <= 0 or pos2[0] >= img_w + or pos2[1] <= 0 or pos2[1] >= img_h + or kpts[sk[0], 2] < kpt_score_thr + or kpts[sk[1], 2] < kpt_score_thr + or pose_link_color[sk_id] is None): + # skip the link that should not be drawn + continue + color = tuple(int(c) for c in pose_link_color[sk_id]) + if show_keypoint_weight: + img_copy = img.copy() + X = (pos1[0], pos2[0]) + Y = (pos1[1], pos2[1]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 + angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = 2 + polygon = cv2.ellipse2Poly( + (int(mX), int(mY)), (int(length / 2), int(stickwidth)), + int(angle), 0, 360, 1) + cv2.fillConvexPoly(img_copy, polygon, color) + transparency = max( + 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) + cv2.addWeighted( + img_copy, + transparency, + img, + 1 - transparency, + 0, + dst=img) + else: + cv2.line(img, pos1, pos2, color, thickness=thickness) + + return img + + +@NODES.register_module() +class ObjectVisualizerNode(BaseVisualizerNode): + """Visualize the bounding box and keypoints of objects. + + Args: + name (str): The node name (also thread name) + input_buffer (str): The name of the input buffer + output_buffer (str|list): The name(s) of the output buffer(s) + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: (1) If ``enable_key`` is set, + the ``bypass()`` method need to be overridden to define the node + behavior when disabled; (2) Some hot-keys are reserved for + particular use. For example: 'q', 'Q' and 27 are used for exiting. + Default: ``None`` + enable (bool): Default enable/disable status. Default: ``True`` + show_bbox (bool): Set ``True`` to show the bboxes of detection + objects. Default: ``True`` + show_keypoint (bool): Set ``True`` to show the pose estimation + results. Default: ``True`` + must_have_bbox (bool): Only show objects with keypoints. + Default: ``False`` + kpt_thr (float): The threshold of keypoint score. Default: 0.3 + radius (int): The radius of keypoint. Default: 4 + thickness (int): The thickness of skeleton. Default: 2 + bbox_color (str|tuple|dict): The color of bboxes. If a single color is + given (a str like 'green' or a BGR tuple like (0, 255, 0)), it + will be used for all bboxes. If a dict is given, it will be used + as a map from class labels to bbox colors. If not given, a default + color map will be used. Default: ``None`` + + Example:: + >>> cfg = dict( + ... type='ObjectVisualizerNode', + ... name='object visualizer', + ... enable_key='v', + ... enable=True, + ... show_bbox=True, + ... must_have_keypoint=False, + ... show_keypoint=True, + ... input_buffer='frame', + ... output_buffer='vis') + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + default_bbox_color = { + 'person': (148, 139, 255), + 'cat': (255, 255, 0), + 'dog': (255, 255, 0), + } + + def __init__(self, + name: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = True, + show_bbox: bool = True, + show_keypoint: bool = True, + must_have_keypoint: bool = False, + kpt_thr: float = 0.3, + radius: int = 4, + thickness: int = 2, + bbox_color: Optional[Union[str, Tuple, Dict]] = 'green'): + + super().__init__( + name=name, + input_buffer=input_buffer, + output_buffer=output_buffer, + enable_key=enable_key, + enable=enable) + + self.kpt_thr = kpt_thr + self.bbox_color = bbox_color + self.show_bbox = show_bbox + self.show_keypoint = show_keypoint + self.must_have_keypoint = must_have_keypoint + self.radius = radius + self.thickness = thickness + + def _draw_bbox(self, canvas: np.ndarray, input_msg: FrameMessage): + """Draw object bboxes.""" + + if self.must_have_keypoint: + objects = input_msg.get_objects( + lambda x: 'bbox' in x and 'keypoints' in x) + else: + objects = input_msg.get_objects(lambda x: 'bbox' in x) + # return if there is no detected objects + if not objects: + return canvas + + bboxes = [obj['bbox'] for obj in objects] + labels = [obj.get('label', None) for obj in objects] + default_color = (0, 255, 0) + + # Get bbox colors + if isinstance(self.bbox_color, dict): + colors = [ + self.bbox_color.get(label, default_color) for label in labels + ] + else: + colors = self.bbox_color + + imshow_bboxes( + canvas, + np.vstack(bboxes), + labels=labels, + colors=colors, + text_color='white', + font_scale=0.5) + + return canvas + + def _draw_keypoint(self, canvas: np.ndarray, input_msg: FrameMessage): + """Draw object keypoints.""" + objects = input_msg.get_objects(lambda x: 'pose_model_cfg' in x) + + # return if there is no object with keypoints + if not objects: + return canvas + + for model_cfg, group in groupby(objects, + lambda x: x['pose_model_cfg']): + dataset_info = objects[0]['dataset_meta'] + keypoints = [ + np.concatenate( + (obj['keypoints'], obj['keypoint_scores'][:, None]), + axis=1) for obj in group + ] + imshow_keypoints( + canvas, + keypoints, + skeleton=dataset_info['skeleton_links'], + kpt_score_thr=self.kpt_thr, + pose_kpt_color=dataset_info['keypoint_colors'], + pose_link_color=dataset_info['skeleton_link_colors'], + radius=self.radius, + thickness=self.thickness) + + return canvas + + def draw(self, input_msg: FrameMessage) -> np.ndarray: + canvas = input_msg.get_image() + + if self.show_bbox: + canvas = self._draw_bbox(canvas, input_msg) + + if self.show_keypoint: + canvas = self._draw_keypoint(canvas, input_msg) + + return canvas diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/sunglasses_effect_node.py b/mmpose/apis/webcam/nodes/visualizer_nodes/sunglasses_effect_node.py new file mode 100644 index 0000000000000000000000000000000000000000..7c011177f57c95a959bc97f167eb58229707b641 --- /dev/null +++ b/mmpose/apis/webcam/nodes/visualizer_nodes/sunglasses_effect_node.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import groupby +from typing import Dict, List, Optional, Union + +import cv2 +import numpy as np + +from ...utils import get_eye_keypoint_ids, load_image_from_disk_or_url +from ..base_visualizer_node import BaseVisualizerNode +from ..registry import NODES + + +@NODES.register_module() +class SunglassesEffectNode(BaseVisualizerNode): + """Apply sunglasses effect (draw sunglasses at the facial area)to the + objects with eye keypoints in the frame. + + Args: + name (str): The node name (also thread name) + input_buffer (str): The name of the input buffer + output_buffer (str|list): The name(s) of the output buffer(s) + enable_key (str|int, optional): Set a hot-key to toggle enable/disable + of the node. If an int value is given, it will be treated as an + ascii code of a key. Please note: + 1. If enable_key is set, the bypass method need to be + overridden to define the node behavior when disabled + 2. Some hot-key has been use for particular use. For example: + 'q', 'Q' and 27 are used for quit + Default: ``None`` + enable (bool): Default enable/disable status. Default: ``True``. + kpt_thr (float): The score threshold of valid keypoints. Default: 0.5 + resource_img_path (str, optional): The resource image path or url. + The image should be a pair of sunglasses with white background. + If not specified, the url of a default image will be used. See + ``SunglassesNode.default_resource_img_path``. Default: ``None`` + + Example:: + >>> cfg = dict( + ... type='SunglassesEffectNode', + ... name='sunglasses', + ... enable_key='s', + ... enable=False, + ... input_buffer='vis', + ... output_buffer='vis_sunglasses') + + >>> from mmpose.apis.webcam.nodes import NODES + >>> node = NODES.build(cfg) + """ + + # The image attributes to: + # "https://www.vecteezy.com/vector-art/1932353-summer-sunglasses- + # accessory-isolated-icon" by Vecteezy + default_resource_img_path = ( + 'https://user-images.githubusercontent.com/15977946/' + '170850839-acc59e26-c6b3-48c9-a9ec-87556edb99ed.jpg') + + def __init__(self, + name: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = True, + kpt_thr: float = 0.5, + resource_img_path: Optional[str] = None): + + super().__init__( + name=name, + input_buffer=input_buffer, + output_buffer=output_buffer, + enable_key=enable_key, + enable=enable) + + if resource_img_path is None: + resource_img_path = self.default_resource_img_path + + self.resource_img = load_image_from_disk_or_url(resource_img_path) + self.kpt_thr = kpt_thr + + def draw(self, input_msg): + canvas = input_msg.get_image() + + objects = input_msg.get_objects(lambda x: 'keypoints' in x) + + for dataset_meta, group in groupby(objects, + lambda x: x['dataset_meta']): + left_eye_index, right_eye_index = get_eye_keypoint_ids( + dataset_meta) + canvas = self.apply_sunglasses_effect(canvas, group, + left_eye_index, + right_eye_index) + return canvas + + def apply_sunglasses_effect(self, canvas: np.ndarray, objects: List[Dict], + left_eye_index: int, + right_eye_index: int) -> np.ndarray: + """Apply sunglasses effect. + + Args: + canvas (np.ndarray): The image to apply the effect + objects (list[dict]): The object list with keypoints + - "keypoints" ([K,3]): keypoints in [x, y, score] + left_eye_index (int): Keypoint index of the left eye + right_eye_index (int): Keypoint index of the right eye + + Returns: + np.ndarray: Processed image + """ + + hm, wm = self.resource_img.shape[:2] + # anchor points in the sunglasses image + pts_src = np.array([[0.3 * wm, 0.3 * hm], [0.3 * wm, 0.7 * hm], + [0.7 * wm, 0.3 * hm], [0.7 * wm, 0.7 * hm]], + dtype=np.float32) + + for obj in objects: + kpts = obj['keypoints'] + kpt_scores = obj['keypoint_scores'] + + if kpt_scores[left_eye_index] < self.kpt_thr or kpt_scores[ + right_eye_index] < self.kpt_thr: + continue + + kpt_leye = kpts[left_eye_index, :2] + kpt_reye = kpts[right_eye_index, :2] + # orthogonal vector to the left-to-right eyes + vo = 0.5 * (kpt_reye - kpt_leye)[::-1] * [-1, 1] + + # anchor points in the image by eye positions + pts_tar = np.vstack( + [kpt_reye + vo, kpt_reye - vo, kpt_leye + vo, kpt_leye - vo]) + + h_mat, _ = cv2.findHomography(pts_src, pts_tar) + patch = cv2.warpPerspective( + self.resource_img, + h_mat, + dsize=(canvas.shape[1], canvas.shape[0]), + borderValue=(255, 255, 255)) + # mask the white background area in the patch with a threshold 200 + mask = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY) + mask = (mask < 200).astype(np.uint8) + canvas = cv2.copyTo(patch, mask, canvas) + + return canvas diff --git a/mmpose/apis/webcam/utils/__init__.py b/mmpose/apis/webcam/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2911bcd5bf451477d0f4266d89786028cae68293 --- /dev/null +++ b/mmpose/apis/webcam/utils/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .buffer import BufferManager +from .event import EventManager +from .image_capture import ImageCapture +from .message import FrameMessage, Message, VideoEndingMessage +from .misc import (copy_and_paste, expand_and_clamp, get_cached_file_path, + get_config_path, is_image_file, limit_max_fps, + load_image_from_disk_or_url, screen_matting) +from .pose import (get_eye_keypoint_ids, get_face_keypoint_ids, + get_hand_keypoint_ids, get_mouth_keypoint_ids, + get_wrist_keypoint_ids) + +__all__ = [ + 'BufferManager', 'EventManager', 'FrameMessage', 'Message', + 'limit_max_fps', 'VideoEndingMessage', 'load_image_from_disk_or_url', + 'get_cached_file_path', 'screen_matting', 'get_config_path', + 'expand_and_clamp', 'copy_and_paste', 'is_image_file', 'ImageCapture', + 'get_eye_keypoint_ids', 'get_face_keypoint_ids', 'get_wrist_keypoint_ids', + 'get_mouth_keypoint_ids', 'get_hand_keypoint_ids' +] diff --git a/mmpose/apis/webcam/utils/buffer.py b/mmpose/apis/webcam/utils/buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f8b9864ee0136173b161a352138c9fac8e879e --- /dev/null +++ b/mmpose/apis/webcam/utils/buffer.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import wraps +from queue import Queue +from typing import Any, Dict, List, Optional + +from mmengine import is_seq_of + +__all__ = ['BufferManager'] + + +def check_buffer_registered(exist=True): + """A function wrapper to check the buffer existence before it is being used + by the wrapped function. + + Args: + exist (bool): If set to ``True``, assert the buffer exists; if set to + ``False``, assert the buffer does not exist. Default: ``True`` + """ + + def wrapper(func): + + @wraps(func) + def wrapped(manager, name, *args, **kwargs): + if exist: + # Assert buffer exist + if name not in manager: + raise ValueError(f'Fail to call {func.__name__}: ' + f'buffer "{name}" is not registered.') + else: + # Assert buffer not exist + if name in manager: + raise ValueError(f'Fail to call {func.__name__}: ' + f'buffer "{name}" is already registered.') + return func(manager, name, *args, **kwargs) + + return wrapped + + return wrapper + + +class Buffer(Queue): + + def put_force(self, item: Any): + """Force to put an item into the buffer. + + If the buffer is already full, the earliest item in the buffer will be + remove to make room for the incoming item. + + Args: + item (any): The item to put into the buffer + """ + with self.mutex: + if self.maxsize > 0: + while self._qsize() >= self.maxsize: + _ = self._get() + self.unfinished_tasks -= 1 + + self._put(item) + self.unfinished_tasks += 1 + self.not_empty.notify() + + +class BufferManager(): + """A helper class to manage multiple buffers. + + Parameters: + buffer_type (type): The class to build buffer instances. Default: + :class:`mmpose.apis.webcam.utils.buffer.Buffer`. + buffers (dict, optional): Create :class:`BufferManager` from existing + buffers. Each item should a buffer name and the buffer. If not + given, an empty buffer manager will be create. Default: ``None`` + """ + + def __init__(self, + buffer_type: type = Buffer, + buffers: Optional[Dict] = None): + self.buffer_type = buffer_type + if buffers is None: + self._buffers = {} + else: + if is_seq_of(list(buffers.values()), buffer_type): + self._buffers = buffers.copy() + else: + raise ValueError('The values of buffers should be instance ' + f'of {buffer_type}') + + def __contains__(self, name): + return name in self._buffers + + @check_buffer_registered(False) + def register_buffer(self, name, maxsize: int = 0): + """Register a buffer. + + If the buffer already exists, an ValueError will be raised. + + Args: + name (any): The buffer name + maxsize (int): The capacity of the buffer. If set to 0, the + capacity is unlimited. Default: 0 + """ + self._buffers[name] = self.buffer_type(maxsize) + + @check_buffer_registered() + def put(self, name, item, block: bool = True, timeout: float = None): + """Put an item into specified buffer. + + Args: + name (any): The buffer name + item (any): The item to put into the buffer + block (bool): If set to ``True``, block if necessary util a free + slot is available in the target buffer. It blocks at most + ``timeout`` seconds and raises the ``Full`` exception. + Otherwise, put an item on the queue if a free slot is + immediately available, else raise the ``Full`` exception. + Default: ``True`` + timeout (float, optional): The most waiting time in seconds if + ``block`` is ``True``. Default: ``None`` + """ + self._buffers[name].put(item, block, timeout) + + @check_buffer_registered() + def put_force(self, name, item): + """Force to put an item into specified buffer. If the buffer was full, + the earliest item within the buffer will be popped out to make a free + slot. + + Args: + name (any): The buffer name + item (any): The item to put into the buffer + """ + self._buffers[name].put_force(item) + + @check_buffer_registered() + def get(self, name, block: bool = True, timeout: float = None) -> Any: + """Remove an return an item from the specified buffer. + + Args: + name (any): The buffer name + block (bool): If set to ``True``, block if necessary until an item + is available in the target buffer. It blocks at most + ``timeout`` seconds and raises the ``Empty`` exception. + Otherwise, return an item if one is immediately available, + else raise the ``Empty`` exception. Default: ``True`` + timeout (float, optional): The most waiting time in seconds if + ``block`` is ``True``. Default: ``None`` + + Returns: + any: The returned item. + """ + return self._buffers[name].get(block, timeout) + + @check_buffer_registered() + def is_empty(self, name) -> bool: + """Check if a buffer is empty. + + Args: + name (any): The buffer name + + Returns: + bool: Weather the buffer is empty. + """ + return self._buffers[name].empty() + + @check_buffer_registered() + def is_full(self, name): + """Check if a buffer is full. + + Args: + name (any): The buffer name + + Returns: + bool: Weather the buffer is full. + """ + return self._buffers[name].full() + + def get_sub_manager(self, buffer_names: List[str]) -> 'BufferManager': + """Return a :class:`BufferManager` instance that covers a subset of the + buffers in the parent. The is usually used to partially share the + buffers of the executor to the node. + + Args: + buffer_names (list): The list of buffers to create the sub manager + + Returns: + BufferManager: The created sub buffer manager. + """ + buffers = {name: self._buffers[name] for name in buffer_names} + return BufferManager(self.buffer_type, buffers) + + def get_info(self): + """Returns the information of all buffers in the manager. + + Returns: + dict[any, dict]: Each item is a buffer name and the information + dict of that buffer. + """ + buffer_info = {} + for name, buffer in self._buffers.items(): + buffer_info[name] = { + 'size': buffer.qsize(), + 'maxsize': buffer.maxsize + } + return buffer_info diff --git a/mmpose/apis/webcam/utils/event.py b/mmpose/apis/webcam/utils/event.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e88e1d8bdbecdeef0352cfd1ef00fd0940725a --- /dev/null +++ b/mmpose/apis/webcam/utils/event.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from collections import defaultdict +from contextlib import contextmanager +from threading import Event +from typing import Optional + +logger = logging.getLogger('Event') + + +class EventManager(): + """A helper class to manage events. + + :class:`EventManager` provides interfaces to register, set, clear and + check events by name. + """ + + def __init__(self): + self._events = defaultdict(Event) + + def register_event(self, event_name: str, is_keyboard: bool = False): + """Register an event. A event must be registered first before being + set, cleared or checked. + + Args: + event_name (str): The indicator of the event. The name should be + unique in one :class:`EventManager` instance + is_keyboard (bool): Specify weather it is a keyboard event. If so, + the ``event_name`` should be the key value, and the indicator + will be set as ``'_keyboard_{event_name}'``. Otherwise, the + ``event_name`` will be directly used as the indicator. + Default: ``False`` + """ + if is_keyboard: + event_name = self._get_keyboard_event_name(event_name) + self._events[event_name] = Event() + + def set(self, event_name: str, is_keyboard: bool = False): + """Set the internal flag of an event to ``True``. + + Args: + event_name (str): The indicator of the event + is_keyboard (bool): Specify weather it is a keyboard event. See + ``register_event()`` for details. Default: False + """ + if is_keyboard: + event_name = self._get_keyboard_event_name(event_name) + self._events[event_name].set() + logger.info(f'Event {event_name} is set.') + + def wait(self, + event_name: str = None, + is_keyboard: bool = False, + timeout: Optional[float] = None) -> bool: + """Block until the internal flag of an event is ``True``. + + Args: + event_name (str): The indicator of the event + is_keyboard (bool): Specify weather it is a keyboard event. See + ``register_event()`` for details. Default: False + timeout (float, optional): The optional maximum blocking time in + seconds. Default: ``None`` + + Returns: + bool: The internal event flag on exit. + """ + if is_keyboard: + event_name = self._get_keyboard_event_name(event_name) + return self._events[event_name].wait(timeout) + + def is_set(self, + event_name: str = None, + is_keyboard: Optional[bool] = False) -> bool: + """Check weather the internal flag of an event is ``True``. + + Args: + event_name (str): The indicator of the event + is_keyboard (bool): Specify weather it is a keyboard event. See + ``register_event()`` for details. Default: False + Returns: + bool: The internal event flag. + """ + if is_keyboard: + event_name = self._get_keyboard_event_name(event_name) + return self._events[event_name].is_set() + + def clear(self, + event_name: str = None, + is_keyboard: Optional[bool] = False): + """Reset the internal flag of en event to False. + + Args: + event_name (str): The indicator of the event + is_keyboard (bool): Specify weather it is a keyboard event. See + ``register_event()`` for details. Default: False + """ + if is_keyboard: + event_name = self._get_keyboard_event_name(event_name) + self._events[event_name].clear() + logger.info(f'Event {event_name} is cleared.') + + @staticmethod + def _get_keyboard_event_name(key): + """Get keyboard event name from the key value.""" + return f'_keyboard_{chr(key) if isinstance(key,int) else key}' + + @contextmanager + def wait_and_handle(self, + event_name: str = None, + is_keyboard: Optional[bool] = False): + """Context manager that blocks until an evenet is set ``True`` and then + goes into the context. + + The internal event flag will be reset ``False`` automatically before + entering the context. + + Args: + event_name (str): The indicator of the event + is_keyboard (bool): Specify weather it is a keyboard event. See + ``register_event()`` for details. Default: False + + Example:: + >>> from mmpose.apis.webcam.utils import EventManager + >>> manager = EventManager() + >>> manager.register_event('q', is_keybard=True) + + >>> # Once the keyboard event `q` is set, ``wait_and_handle`` + >>> # will reset the event and enter the context to invoke + >>> # ``foo()`` + >>> with manager.wait_and_handle('q', is_keybard=True): + ... foo() + """ + self.wait(event_name, is_keyboard) + try: + yield + finally: + self.clear(event_name, is_keyboard) diff --git a/mmpose/apis/webcam/utils/image_capture.py b/mmpose/apis/webcam/utils/image_capture.py new file mode 100644 index 0000000000000000000000000000000000000000..fb28acff942d007b26345f5633746d6f948b9e70 --- /dev/null +++ b/mmpose/apis/webcam/utils/image_capture.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import cv2 +import numpy as np + +from .misc import load_image_from_disk_or_url + + +class ImageCapture: + """A mock-up of cv2.VideoCapture that always return a const image. + + Args: + image (str | ndarray): The image path or image data + """ + + def __init__(self, image: Union[str, np.ndarray]): + if isinstance(image, str): + self.image = load_image_from_disk_or_url(image) + else: + self.image = image + + def isOpened(self): + return (self.image is not None) + + def read(self): + return True, self.image.copy() + + def release(self): + pass + + def get(self, propId): + if propId == cv2.CAP_PROP_FRAME_WIDTH: + return self.image.shape[1] + elif propId == cv2.CAP_PROP_FRAME_HEIGHT: + return self.image.shape[0] + elif propId == cv2.CAP_PROP_FPS: + return np.nan + else: + raise NotImplementedError() diff --git a/mmpose/apis/webcam/utils/message.py b/mmpose/apis/webcam/utils/message.py new file mode 100644 index 0000000000000000000000000000000000000000..8961ea39c29e88ebe7a42bd250656f417859e7c8 --- /dev/null +++ b/mmpose/apis/webcam/utils/message.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +import uuid +import warnings +from typing import Callable, Dict, List, Optional + +import numpy as np + +Filter = Callable[[Dict], bool] + + +class Message(): + """Message base class. + + All message class should inherit this class. The basic use of a Message + instance is to carray a piece of text message (self.msg) and a dict that + stores structured data (self.data), e.g. frame image, model prediction, + et al. + + A message may also hold route information, which is composed of + information of all nodes the message has passed through. + + Parameters: + msg (str): The text message. + data (dict, optional): The structured data. + """ + + def __init__(self, msg: str = '', data: Optional[Dict] = None): + self.msg = msg + self.data = data if data else {} + self.route_info = [] + self.timestamp = time.time() + self.id = uuid.uuid1() + + def update_route_info(self, + node=None, + node_name: Optional[str] = None, + node_type: Optional[str] = None, + info: Optional[Dict] = None): + """Append new node information to the route information. + + Args: + node (Node, optional): An instance of Node that provides basic + information like the node name and type. Default: ``None``. + node_name (str, optional): The node name. If node is given, + node_name will be ignored. Default: ``None``. + node_type (str, optional): The class name of the node. If node + is given, node_type will be ignored. Default: ``None``. + info (dict, optional): The node information, which is usually + given by node.get_node_info(). Default: ``None``. + """ + if node is not None: + if node_name is not None or node_type is not None: + warnings.warn( + '`node_name` and `node_type` will be overridden if node ' + 'is provided.') + node_name = node.name + node_type = node.__class__.__name__ + + node_info = {'node': node_name, 'node_type': node_type, 'info': info} + self.route_info.append(node_info) + + def set_route_info(self, route_info: List[Dict]): + """Directly set the entire route information. + + Args: + route_info (list): route information to set to the message. + """ + self.route_info = route_info + + def merge_route_info(self, route_info: List[Dict]): + """Merge the given route information into the original one of the + message. This is used for combining route information from multiple + messages. The node information in the route will be reordered according + to their timestamps. + + Args: + route_info (list): route information to merge. + """ + self.route_info += route_info + self.route_info.sort(key=lambda x: x.get('timestamp', np.inf)) + + def get_route_info(self) -> List: + return self.route_info.copy() + + +class VideoEndingMessage(Message): + """The special message to indicate the ending of the input video.""" + + +class FrameMessage(Message): + """The message to store information of a video frame.""" + + def __init__(self, img): + super().__init__(data=dict(image=img, objects={}, model_cfgs={})) + + def get_image(self) -> np.ndarray: + """Get the frame image. + + Returns: + np.ndarray: The frame image. + """ + return self.data.get('image', None) + + def set_image(self, img): + """Set the frame image to the message. + + Args: + img (np.ndarray): The frame image. + """ + self.data['image'] = img + + def set_objects(self, objects: List[Dict]): + """Set the object information. The old object information will be + cleared. + + Args: + objects (list[dict]): A list of object information + + See also :func:`update_objects`. + """ + self.data['objects'] = {} + self.update_objects(objects) + + def update_objects(self, objects: List[Dict]): + """Update object information. + + Each object will be assigned an unique ID if it does not has one. If + an object's ID already exists in ``self.data['objects']``, the object + information will be updated; otherwise it will be added as a new + object. + + Args: + objects (list[dict]): A list of object information + """ + for obj in objects: + if '_id_' in obj: + # get the object id if it exists + obj_id = obj['_id_'] + else: + # otherwise assign a new object id + obj_id = uuid.uuid1() + obj['_id_'] = obj_id + self.data['objects'][obj_id] = obj + + def get_objects(self, obj_filter: Optional[Filter] = None) -> List[Dict]: + """Get object information from the frame data. + + Default to return all objects in the frame data. Optionally, filters + can be set to retrieve objects with specific keys and values. The + filters are represented as a dict. Each key in the filters specifies a + required key of the object. Each value in the filters is a tuple that + enumerate the required values of the corresponding key in the object. + + Args: + obj_filter (callable, optional): A filter function that returns a + bool value from a object (dict). If provided, only objects + that return True will be retrieved. Otherwise all objects will + be retrieved. Default: ``None``. + + Returns: + list[dict]: A list of object information. + + + Example:: + >>> objects = [ + ... {'_id_': 2, 'label': 'dog'} + ... {'_id_': 1, 'label': 'cat'}, + ... ] + >>> frame = FrameMessage(img) + >>> frame.set_objects(objects) + >>> frame.get_objects() + [ + {'_id_': 1, 'label': 'cat'}, + {'_id_': 2, 'label': 'dog'} + ] + >>> frame.get_objects(obj_filter=lambda x:x['label'] == 'cat') + [{'_id_': 1, 'label': 'cat'}] + """ + + objects = [ + obj.copy() + for obj in filter(obj_filter, self.data['objects'].values()) + ] + + return objects diff --git a/mmpose/apis/webcam/utils/misc.py b/mmpose/apis/webcam/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6f5417aeeceb40fbe92ddb0aa8e27a9a0faadc --- /dev/null +++ b/mmpose/apis/webcam/utils/misc.py @@ -0,0 +1,367 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +import os +import os.path as osp +import sys +import time +from contextlib import contextmanager +from typing import List, Optional, Tuple +from urllib.parse import urlparse +from urllib.request import urlopen + +import cv2 +import numpy as np +from mmengine import mkdir_or_exist +from torch.hub import HASH_REGEX, download_url_to_file + + +@contextmanager +def limit_max_fps(fps: float): + """A context manager to limit maximum frequence of entering the context. + + Args: + fps (float): The maximum frequence of entering the context + + Example:: + >>> from mmpose.apis.webcam.utils import limit_max_fps + >>> import cv2 + + >>> while True: + ... with limit_max_fps(20): + ... cv2.imshow(img) # display image at most 20 fps + """ + t_start = time.time() + try: + yield + finally: + t_end = time.time() + if fps is not None: + t_sleep = 1.0 / fps - t_end + t_start + if t_sleep > 0: + time.sleep(t_sleep) + + +def _is_url(filename: str) -> bool: + """Check if the file is a url link. + + Args: + filename (str): the file name or url link + + Returns: + bool: is url or not. + """ + prefixes = ['http://', 'https://'] + for p in prefixes: + if filename.startswith(p): + return True + return False + + +def load_image_from_disk_or_url(filename: str, + readFlag: int = cv2.IMREAD_COLOR + ) -> np.ndarray: + """Load an image file, from disk or url. + + Args: + filename (str): file name on the disk or url link + readFlag (int): readFlag for imdecode. Default: cv2.IMREAD_COLOR + + Returns: + np.ndarray: A loaded image + """ + if _is_url(filename): + # download the image, convert it to a NumPy array, and then read + # it into OpenCV format + resp = urlopen(filename) + image = np.asarray(bytearray(resp.read()), dtype='uint8') + image = cv2.imdecode(image, readFlag) + return image + else: + image = cv2.imread(filename, readFlag) + return image + + +def get_cached_file_path(url: str, + save_dir: str, + progress: bool = True, + check_hash: bool = False, + file_name: Optional[str] = None) -> str: + r"""Loads the Torch serialized object at the given URL. + + If downloaded file is a zip file, it will be automatically decompressed + + If the object is already present in `model_dir`, it's deserialized and + returned. + The default value of ``model_dir`` is ``/checkpoints`` where + ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + url (str): URL of the object to download + save_dir (str): directory in which to save the object + progress (bool): whether or not to display a progress bar + to stderr. Default: ``True`` + check_hash(bool): If True, the filename part of the URL + should follow the naming convention ``filename-.ext`` + where ```` is the first eight or more digits of the + SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. + Default: ``False`` + file_name (str, optional): name for the downloaded file. Filename + from ``url`` will be used if not set. Default: ``None``. + + Returns: + str: The path to the cached file. + """ + + mkdir_or_exist(save_dir) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.join(save_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + return cached_file + + +def screen_matting(img: np.ndarray, + color_low: Optional[Tuple] = None, + color_high: Optional[Tuple] = None, + color: Optional[str] = None) -> np.ndarray: + """Get screen matting mask. + + Args: + img (np.ndarray): Image data. + color_low (tuple): Lower limit (b, g, r). + color_high (tuple): Higher limit (b, g, r). + color (str): Support colors include: + + - 'green' or 'g' + - 'blue' or 'b' + - 'black' or 'k' + - 'white' or 'w' + + Returns: + np.ndarray: A mask with the same shape of the input image. The value + is 0 at the pixels in the matting color range, and 1 everywhere else. + """ + + if color_high is None or color_low is None: + if color is not None: + if color.lower() == 'g' or color.lower() == 'green': + color_low = (0, 200, 0) + color_high = (60, 255, 60) + elif color.lower() == 'b' or color.lower() == 'blue': + color_low = (230, 0, 0) + color_high = (255, 40, 40) + elif color.lower() == 'k' or color.lower() == 'black': + color_low = (0, 0, 0) + color_high = (40, 40, 40) + elif color.lower() == 'w' or color.lower() == 'white': + color_low = (230, 230, 230) + color_high = (255, 255, 255) + else: + raise NotImplementedError(f'Not supported color: {color}.') + else: + raise ValueError( + 'color or color_high | color_low should be given.') + + mask = cv2.inRange(img, np.array(color_low), np.array(color_high)) == 0 + + return mask.astype(np.uint8) + + +def expand_and_clamp(box: List, im_shape: Tuple, scale: float = 1.25) -> List: + """Expand the bbox and clip it to fit the image shape. + + Args: + box (list): x1, y1, x2, y2 + im_shape (tuple): image shape (h, w, c) + scale (float): expand ratio + + Returns: + list: x1, y1, x2, y2 + """ + + x1, y1, x2, y2 = box[:4] + w = x2 - x1 + h = y2 - y1 + deta_w = w * (scale - 1) / 2 + deta_h = h * (scale - 1) / 2 + + x1, y1, x2, y2 = x1 - deta_w, y1 - deta_h, x2 + deta_w, y2 + deta_h + + img_h, img_w = im_shape[:2] + + x1 = min(max(0, int(x1)), img_w - 1) + y1 = min(max(0, int(y1)), img_h - 1) + x2 = min(max(0, int(x2)), img_w - 1) + y2 = min(max(0, int(y2)), img_h - 1) + + return [x1, y1, x2, y2] + + +def _find_bbox(mask): + """Find the bounding box for the mask. + + Args: + mask (ndarray): Mask. + + Returns: + list(4, ): Returned box (x1, y1, x2, y2). + """ + mask_shape = mask.shape + if len(mask_shape) == 3: + assert mask_shape[-1] == 1, 'the channel of the mask should be 1.' + elif len(mask_shape) == 2: + pass + else: + NotImplementedError() + + h, w = mask_shape[:2] + mask_w = mask.sum(0) + mask_h = mask.sum(1) + + left = 0 + right = w - 1 + up = 0 + down = h - 1 + + for i in range(w): + if mask_w[i] > 0: + break + left += 1 + + for i in range(w - 1, left, -1): + if mask_w[i] > 0: + break + right -= 1 + + for i in range(h): + if mask_h[i] > 0: + break + up += 1 + + for i in range(h - 1, up, -1): + if mask_h[i] > 0: + break + down -= 1 + + return [left, up, right, down] + + +def copy_and_paste( + img: np.ndarray, + background_img: np.ndarray, + mask: np.ndarray, + bbox: Optional[List] = None, + effect_region: Tuple = (0.2, 0.2, 0.8, 0.8), + min_size: Tuple = (20, 20) +) -> np.ndarray: + """Copy the image region and paste to the background. + + Args: + img (np.ndarray): Image data. + background_img (np.ndarray): Background image data. + mask (ndarray): instance segmentation result. + bbox (list, optional): instance bbox in (x1, y1, x2, y2). If not + given, the bbox will be obtained by ``_find_bbox()``. Default: + ``None`` + effect_region (tuple): The region to apply mask, the coordinates + are normalized (x1, y1, x2, y2). Default: (0.2, 0.2, 0.8, 0.8) + min_size (tuple): The minimum region size (w, h) in pixels. + Default: (20, 20) + + Returns: + np.ndarray: The background with pasted image region. + """ + background_img = background_img.copy() + background_h, background_w = background_img.shape[:2] + region_h = (effect_region[3] - effect_region[1]) * background_h + region_w = (effect_region[2] - effect_region[0]) * background_w + region_aspect_ratio = region_w / region_h + + if bbox is None: + bbox = _find_bbox(mask) + instance_w = bbox[2] - bbox[0] + instance_h = bbox[3] - bbox[1] + + if instance_w > min_size[0] and instance_h > min_size[1]: + aspect_ratio = instance_w / instance_h + if region_aspect_ratio > aspect_ratio: + resize_rate = region_h / instance_h + else: + resize_rate = region_w / instance_w + + mask_inst = mask[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] + img_inst = img[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] + img_inst = cv2.resize( + img_inst.astype('float32'), + (int(resize_rate * instance_w), int(resize_rate * instance_h))) + img_inst = img_inst.astype(background_img.dtype) + mask_inst = cv2.resize( + mask_inst.astype('float32'), + (int(resize_rate * instance_w), int(resize_rate * instance_h)), + interpolation=cv2.INTER_NEAREST) + + mask_ids = list(np.where(mask_inst == 1)) + mask_ids[1] += int(effect_region[0] * background_w) + mask_ids[0] += int(effect_region[1] * background_h) + + background_img[tuple(mask_ids)] = img_inst[np.where(mask_inst == 1)] + + return background_img + + +def is_image_file(path: str) -> bool: + """Check if a path is an image file by its extension. + + Args: + path (str): The image path. + + Returns: + bool: Weather the path is an image file. + """ + if isinstance(path, str): + if path.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp')): + return True + return False + + +def get_config_path(path: str, module_name: str): + """Get config path from an OpenMMLab codebase. + + If the path is an existing file, it will be directly returned. If the file + doesn't exist, it will be searched in the 'configs' folder of the + specified module. + + Args: + path (str): the path of the config file + module_name (str): The module name of an OpenMMLab codebase + + Returns: + str: The config file path. + + Example:: + >>> path = 'configs/_base_/filters/one_euro.py' + >>> get_config_path(path, 'mmpose') + '/home/mmpose/configs/_base_/filters/one_euro.py' + """ + + if osp.isfile(path): + return path + + module = importlib.import_module(module_name) + module_dir = osp.dirname(module.__file__) + path_in_module = osp.join(module_dir, '.mim', path) + + if not osp.isfile(path_in_module): + raise FileNotFoundError(f'Can not find the config file "{path}"') + + return path_in_module diff --git a/mmpose/apis/webcam/utils/pose.py b/mmpose/apis/webcam/utils/pose.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff32f9e169ad4070f7ffb8d1f2664dfe50b8013 --- /dev/null +++ b/mmpose/apis/webcam/utils/pose.py @@ -0,0 +1,181 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + + +def get_eye_keypoint_ids(dataset_meta: Dict) -> Tuple[int, int]: + """A helper function to get the keypoint indices of left and right eyes + from the dataset meta information. + + Args: + dataset_meta (dict): dataset meta information. + + Returns: + tuple[int, int]: The keypoint indices of left eye and right eye. + """ + left_eye_idx = None + right_eye_idx = None + + # try obtaining eye point ids from dataset_meta + keypoint_name2id = dataset_meta.get('keypoint_name2id', {}) + left_eye_idx = keypoint_name2id.get('left_eye', None) + right_eye_idx = keypoint_name2id.get('right_eye', None) + + if left_eye_idx is None or right_eye_idx is None: + # Fall back to hard coded keypoint id + dataset_name = dataset_meta.get('dataset_name', 'unknown dataset') + if dataset_name in {'coco', 'coco_wholebody'}: + left_eye_idx = 1 + right_eye_idx = 2 + elif dataset_name in {'animalpose', 'ap10k'}: + left_eye_idx = 0 + right_eye_idx = 1 + else: + raise ValueError('Can not determine the eye keypoint id of ' + f'{dataset_name}') + + return left_eye_idx, right_eye_idx + + +def get_face_keypoint_ids(dataset_meta: Dict) -> List: + """A helper function to get the keypoint indices of the face from the + dataset meta information. + + Args: + dataset_meta (dict): dataset meta information. + + Returns: + list[int]: face keypoint indices. The length depends on the dataset. + """ + face_indices = [] + + # try obtaining nose point ids from dataset_meta + keypoint_name2id = dataset_meta.get('keypoint_name2id', {}) + for id in range(68): + face_indices.append(keypoint_name2id.get(f'face-{id}', None)) + + if None in face_indices: + # Fall back to hard coded keypoint id + dataset_name = dataset_meta.get('dataset_name', 'unknown dataset') + if dataset_name in {'coco_wholebody'}: + face_indices = list(range(23, 91)) + else: + raise ValueError('Can not determine the face id of ' + f'{dataset_name}') + + return face_indices + + +def get_wrist_keypoint_ids(dataset_meta: Dict) -> Tuple[int, int]: + """A helper function to get the keypoint indices of left and right wrists + from the dataset meta information. + + Args: + dataset_meta (dict): dataset meta information. + Returns: + tuple[int, int]: The keypoint indices of left and right wrists. + """ + + # try obtaining wrist point ids from dataset_meta + keypoint_name2id = dataset_meta.get('keypoint_name2id', {}) + left_wrist_idx = keypoint_name2id.get('left_wrist', None) + right_wrist_idx = keypoint_name2id.get('right_wrist', None) + + if left_wrist_idx is None or right_wrist_idx is None: + # Fall back to hard coded keypoint id + dataset_name = dataset_meta.get('dataset_name', 'unknown dataset') + if dataset_name in {'coco', 'coco_wholebody'}: + left_wrist_idx = 9 + right_wrist_idx = 10 + elif dataset_name == 'animalpose': + left_wrist_idx = 16 + right_wrist_idx = 17 + elif dataset_name == 'ap10k': + left_wrist_idx = 7 + right_wrist_idx = 10 + else: + raise ValueError('Can not determine the eye keypoint id of ' + f'{dataset_name}') + + return left_wrist_idx, right_wrist_idx + + +def get_mouth_keypoint_ids(dataset_meta: Dict) -> int: + """A helper function to get the mouth keypoint index from the dataset meta + information. + + Args: + dataset_meta (dict): dataset meta information. + Returns: + int: The mouth keypoint index + """ + # try obtaining mouth point ids from dataset_info + keypoint_name2id = dataset_meta.get('keypoint_name2id', {}) + mouth_index = keypoint_name2id.get('face-62', None) + + if mouth_index is None: + # Fall back to hard coded keypoint id + dataset_name = dataset_meta.get('dataset_name', 'unknown dataset') + if dataset_name == 'coco_wholebody': + mouth_index = 85 + else: + raise ValueError('Can not determine the eye keypoint id of ' + f'{dataset_name}') + + return mouth_index + + +def get_hand_keypoint_ids(dataset_meta: Dict) -> List[int]: + """A helper function to get the keypoint indices of left and right hand + from the dataset meta information. + + Args: + dataset_meta (dict): dataset meta information. + Returns: + list[int]: hand keypoint indices. The length depends on the dataset. + """ + # try obtaining hand keypoint ids from dataset_meta + keypoint_name2id = dataset_meta.get('keypoint_name2id', {}) + hand_indices = [] + hand_indices.append(keypoint_name2id.get('left_hand_root', None)) + + for id in range(1, 5): + hand_indices.append(keypoint_name2id.get(f'left_thumb{id}', None)) + for id in range(1, 5): + hand_indices.append(keypoint_name2id.get(f'left_forefinger{id}', None)) + for id in range(1, 5): + hand_indices.append( + keypoint_name2id.get(f'left_middle_finger{id}', None)) + for id in range(1, 5): + hand_indices.append( + keypoint_name2id.get(f'left_ring_finger{id}', None)) + for id in range(1, 5): + hand_indices.append( + keypoint_name2id.get(f'left_pinky_finger{id}', None)) + + hand_indices.append(keypoint_name2id.get('right_hand_root', None)) + + for id in range(1, 5): + hand_indices.append(keypoint_name2id.get(f'right_thumb{id}', None)) + for id in range(1, 5): + hand_indices.append( + keypoint_name2id.get(f'right_forefinger{id}', None)) + for id in range(1, 5): + hand_indices.append( + keypoint_name2id.get(f'right_middle_finger{id}', None)) + for id in range(1, 5): + hand_indices.append( + keypoint_name2id.get(f'right_ring_finger{id}', None)) + for id in range(1, 5): + hand_indices.append( + keypoint_name2id.get(f'right_pinky_finger{id}', None)) + + if None in hand_indices: + # Fall back to hard coded keypoint id + dataset_name = dataset_meta.get('dataset_name', 'unknown dataset') + if dataset_name in {'coco_wholebody'}: + hand_indices = list(range(91, 133)) + else: + raise ValueError('Can not determine the hand id of ' + f'{dataset_name}') + + return hand_indices diff --git a/mmpose/apis/webcam/webcam_executor.py b/mmpose/apis/webcam/webcam_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..f39aa4b84710c69064fe09e8cf0a3e4db8b006a2 --- /dev/null +++ b/mmpose/apis/webcam/webcam_executor.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +import sys +import time +import warnings +from threading import Thread +from typing import Dict, List, Optional, Tuple, Union + +import cv2 + +from .nodes import NODES +from .utils import (BufferManager, EventManager, FrameMessage, ImageCapture, + VideoEndingMessage, is_image_file, limit_max_fps) + +try: + from contextlib import nullcontext +except ImportError: + # compatible with python3.6 + from contextlib import contextmanager + + @contextmanager + def nullcontext(enter_result=None): + yield enter_result + + +DEFAULT_FRAME_BUFFER_SIZE = 1 +DEFAULT_INPUT_BUFFER_SIZE = 1 +DEFAULT_DISPLAY_BUFFER_SIZE = 0 +DEFAULT_USER_BUFFER_SIZE = 1 + +logger = logging.getLogger('Executor') + + +class WebcamExecutor(): + """The interface to build and execute webcam applications from configs. + + Parameters: + nodes (list[dict]): Node configs. See :class:`webcam.nodes.Node` for + details + name (str): Executor name. Default: 'MMPose Webcam App'. + camera_id (int | str): The camera ID (usually the ID of the default + camera is 0). Alternatively a file path or a URL can be given + to load from a video or image file. + camera_frame_shape (tuple, optional): Set the frame shape of the + camera in (width, height). If not given, the default frame shape + will be used. This argument is only valid when using a camera + as the input source. Default: ``None`` + camera_max_fps (int): Video reading maximum FPS. Default: 30 + buffer_sizes (dict, optional): A dict to specify buffer sizes. The + key is the buffer name and the value is the buffer size. + Default: ``None`` + + Example:: + >>> cfg = dict( + >>> name='Test Webcam', + >>> camera_id=0, + >>> camera_max_fps=30, + >>> nodes=[ + >>> dict( + >>> type='MonitorNode', + >>> name='monitor', + >>> enable_key='m', + >>> enable=False, + >>> input_buffer='_frame_', + >>> output_buffer='display'), + >>> dict( + >>> type='RecorderNode', + >>> name='recorder', + >>> out_video_file='webcam_output.mp4', + >>> input_buffer='display', + >>> output_buffer='_display_') + >>> ]) + + >>> executor = WebcamExecutor(**cfg) + """ + + def __init__(self, + nodes: List[Dict], + name: str = 'MMPose Webcam App', + camera_id: Union[int, str] = 0, + camera_max_fps: int = 30, + camera_frame_shape: Optional[Tuple[int, int]] = None, + synchronous: bool = False, + buffer_sizes: Optional[Dict[str, int]] = None): + + # Basic parameters + self.name = name + self.camera_id = camera_id + self.camera_max_fps = camera_max_fps + self.camera_frame_shape = camera_frame_shape + self.synchronous = synchronous + + # self.buffer_manager manages data flow between executor and nodes + self.buffer_manager = BufferManager() + # self.event_manager manages event-based asynchronous communication + self.event_manager = EventManager() + # self.node_list holds all node instance + self.node_list = [] + # self.vcap is used to read camera frames. It will be built when the + # executor starts running + self.vcap = None + + # Register executor events + self.event_manager.register_event('_exit_', is_keyboard=False) + if self.synchronous: + self.event_manager.register_event('_idle_', is_keyboard=False) + + # Register nodes + if not nodes: + raise ValueError('No node is registered to the executor.') + + # Register default buffers + if buffer_sizes is None: + buffer_sizes = {} + # _frame_ buffer + frame_buffer_size = buffer_sizes.get('_frame_', + DEFAULT_FRAME_BUFFER_SIZE) + self.buffer_manager.register_buffer('_frame_', frame_buffer_size) + # _input_ buffer + input_buffer_size = buffer_sizes.get('_input_', + DEFAULT_INPUT_BUFFER_SIZE) + self.buffer_manager.register_buffer('_input_', input_buffer_size) + # _display_ buffer + display_buffer_size = buffer_sizes.get('_display_', + DEFAULT_DISPLAY_BUFFER_SIZE) + self.buffer_manager.register_buffer('_display_', display_buffer_size) + + # Build all nodes: + for node_cfg in nodes: + logger.info(f'Create node: {node_cfg.name}({node_cfg.type})') + node = NODES.build(node_cfg) + + # Register node + self.node_list.append(node) + + # Register buffers + for buffer_info in node.registered_buffers: + buffer_name = buffer_info.buffer_name + if buffer_name in self.buffer_manager: + continue + buffer_size = buffer_sizes.get(buffer_name, + DEFAULT_USER_BUFFER_SIZE) + self.buffer_manager.register_buffer(buffer_name, buffer_size) + logger.info( + f'Register user buffer: {buffer_name}({buffer_size})') + + # Register events + for event_info in node.registered_events: + self.event_manager.register_event( + event_name=event_info.event_name, + is_keyboard=event_info.is_keyboard) + logger.info(f'Register event: {event_info.event_name}') + + # Set executor for nodes + # This step is performed after node building when the executor has + # create full buffer/event managers and can + for node in self.node_list: + logger.info(f'Set executor for node: {node.name})') + node.set_executor(self) + + def _read_camera(self): + """Read video frames from the caemra (or the source video/image) and + put them into input buffers.""" + + camera_id = self.camera_id + fps = self.camera_max_fps + + # Build video capture + if is_image_file(camera_id): + self.vcap = ImageCapture(camera_id) + else: + self.vcap = cv2.VideoCapture(camera_id) + if self.camera_frame_shape is not None: + width, height = self.camera_frame_shape + self.vcap.set(cv2.CAP_PROP_FRAME_WIDTH, width) + self.vcap.set(cv2.CAP_PROP_FRAME_HEIGHT, height) + + if not self.vcap.isOpened(): + warnings.warn(f'Cannot open camera (ID={camera_id})') + sys.exit() + + # Read video frames in a loop + first_frame = True + while not self.event_manager.is_set('_exit_'): + if self.synchronous: + if first_frame: + cm = nullcontext() + else: + # Read a new frame until the last frame has been processed + cm = self.event_manager.wait_and_handle('_idle_') + else: + # Read frames with a maximum FPS + cm = limit_max_fps(fps) + + first_frame = False + + with cm: + # Read a frame + ret_val, frame = self.vcap.read() + if ret_val: + # Put frame message (for display) into buffer `_frame_` + frame_msg = FrameMessage(frame) + self.buffer_manager.put('_frame_', frame_msg) + + # Put input message (for model inference or other use) + # into buffer `_input_` + input_msg = FrameMessage(frame.copy()) + input_msg.update_route_info( + node_name='Camera Info', + node_type='none', + info=self._get_camera_info()) + self.buffer_manager.put_force('_input_', input_msg) + logger.info('Read one frame.') + else: + logger.info('Reached the end of the video.') + # Put a video ending signal + self.buffer_manager.put_force('_frame_', + VideoEndingMessage()) + self.buffer_manager.put_force('_input_', + VideoEndingMessage()) + # Wait for `_exit_` event util a timeout occurs + if not self.event_manager.wait('_exit_', timeout=5.0): + break + + self.vcap.release() + + def _display(self): + """Receive processed frames from the output buffer and display on + screen.""" + + output_msg = None + + while not self.event_manager.is_set('_exit_'): + while self.buffer_manager.is_empty('_display_'): + time.sleep(0.001) + + # Set _idle_ to allow reading next frame + if self.synchronous: + self.event_manager.set('_idle_') + + # acquire output from buffer + output_msg = self.buffer_manager.get('_display_') + + # None indicates input stream ends + if isinstance(output_msg, VideoEndingMessage): + self.event_manager.set('_exit_') + break + + img = output_msg.get_image() + + # show in a window + cv2.imshow(self.name, img) + + # handle keyboard input + key = cv2.waitKey(1) + if key != -1: + self._on_keyboard_input(key) + + cv2.destroyAllWindows() + + # Avoid dead lock + if self.synchronous: + self.event_manager.set('_idle_') + + def _on_keyboard_input(self, key): + """Handle the keyboard input. + + The key 'Q' and `ESC` will trigger an '_exit_' event, which will be + responded by all nodes and the executor itself to exit. Other keys will + trigger keyboard event to be responded by the nodes which has + registered corresponding event. See :class:`webcam.utils.EventManager` + for details. + """ + + if key in (27, ord('q'), ord('Q')): + logger.info(f'Exit event captured: {key}') + self.event_manager.set('_exit_') + else: + logger.info(f'Keyboard event captured: {key}') + self.event_manager.set(key, is_keyboard=True) + + def _get_camera_info(self): + """Return the camera information in a dict.""" + + frame_width = self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH) + frame_height = self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT) + frame_rate = self.vcap.get(cv2.CAP_PROP_FPS) + + cam_info = { + 'Camera ID': self.camera_id, + 'Camera resolution': f'{frame_width}x{frame_height}', + 'Camera FPS': frame_rate, + } + + return cam_info + + def run(self): + """Start the executor. + + This method starts all nodes as well as video I/O in separate threads. + """ + + try: + # Start node threads + non_daemon_nodes = [] + for node in self.node_list: + node.start() + if not node.daemon: + non_daemon_nodes.append(node) + + # Create a thread to read video frames + t_read = Thread(target=self._read_camera, args=()) + t_read.start() + + # Run display in the main thread + self._display() + logger.info('Display has stopped.') + + # joint non-daemon nodes and executor threads + logger.info('Camera reading is about to join.') + t_read.join() + + for node in non_daemon_nodes: + logger.info(f'Node {node.name} is about to join.') + node.join() + logger.info('All nodes jointed successfully.') + + except KeyboardInterrupt: + pass diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a88ebac70169321c248a70346b1db75c7908e58b --- /dev/null +++ b/mmpose/codecs/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .associative_embedding import AssociativeEmbedding +from .decoupled_heatmap import DecoupledHeatmap +from .integral_regression_label import IntegralRegressionLabel +from .megvii_heatmap import MegviiHeatmap +from .msra_heatmap import MSRAHeatmap +from .regression_label import RegressionLabel +from .simcc_label import SimCCLabel +from .spr import SPR +from .udp_heatmap import UDPHeatmap + +__all__ = [ + 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel', + 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR', + 'DecoupledHeatmap' +] diff --git a/mmpose/codecs/__pycache__/__init__.cpython-38.pyc b/mmpose/codecs/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae501df82620881a4b9e83192275219f68e9762d Binary files /dev/null and b/mmpose/codecs/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/associative_embedding.cpython-38.pyc b/mmpose/codecs/__pycache__/associative_embedding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e407ecb8baf13482acb72c057037f24eee344c34 Binary files /dev/null and b/mmpose/codecs/__pycache__/associative_embedding.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/base.cpython-38.pyc b/mmpose/codecs/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef07cf8a623136b02c2175bb4e01d33286dc98f6 Binary files /dev/null and b/mmpose/codecs/__pycache__/base.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/decoupled_heatmap.cpython-38.pyc b/mmpose/codecs/__pycache__/decoupled_heatmap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..094eb04b5366ee46d1d113f276cd89abee8c877c Binary files /dev/null and b/mmpose/codecs/__pycache__/decoupled_heatmap.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/integral_regression_label.cpython-38.pyc b/mmpose/codecs/__pycache__/integral_regression_label.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea5d7c364b100dc30ccd0aad366f8868a6c97750 Binary files /dev/null and b/mmpose/codecs/__pycache__/integral_regression_label.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/megvii_heatmap.cpython-38.pyc b/mmpose/codecs/__pycache__/megvii_heatmap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3128624daa262bd0a8ac6acad71d969b0ceb4978 Binary files /dev/null and b/mmpose/codecs/__pycache__/megvii_heatmap.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/msra_heatmap.cpython-38.pyc b/mmpose/codecs/__pycache__/msra_heatmap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc63009b8f820cd14f0cdb09f5325b8d5b0da309 Binary files /dev/null and b/mmpose/codecs/__pycache__/msra_heatmap.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/regression_label.cpython-38.pyc b/mmpose/codecs/__pycache__/regression_label.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab1cd72c23e64fbff523ecb6491de1c6ebed1800 Binary files /dev/null and b/mmpose/codecs/__pycache__/regression_label.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/simcc_label.cpython-38.pyc b/mmpose/codecs/__pycache__/simcc_label.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5de9cb61615685f4f32d944eed93a5219c87a1a Binary files /dev/null and b/mmpose/codecs/__pycache__/simcc_label.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/spr.cpython-38.pyc b/mmpose/codecs/__pycache__/spr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ec5689cb7da25c6d34dafee423500d1f5da486f Binary files /dev/null and b/mmpose/codecs/__pycache__/spr.cpython-38.pyc differ diff --git a/mmpose/codecs/__pycache__/udp_heatmap.cpython-38.pyc b/mmpose/codecs/__pycache__/udp_heatmap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbc7726fc0f90d16bade148624b9cbdad70f0cb0 Binary files /dev/null and b/mmpose/codecs/__pycache__/udp_heatmap.cpython-38.pyc differ diff --git a/mmpose/codecs/associative_embedding.py b/mmpose/codecs/associative_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7e080f1657d17deabf6e44bea275432216d621e5 --- /dev/null +++ b/mmpose/codecs/associative_embedding.py @@ -0,0 +1,512 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import namedtuple +from itertools import product +from typing import Any, List, Optional, Tuple + +import numpy as np +import torch +from munkres import Munkres +from torch import Tensor + +from mmpose.registry import KEYPOINT_CODECS +from mmpose.utils.tensor_utils import to_numpy +from .base import BaseKeypointCodec +from .utils import (batch_heatmap_nms, generate_gaussian_heatmaps, + generate_udp_gaussian_heatmaps, refine_keypoints, + refine_keypoints_dark_udp) + + +def _group_keypoints_by_tags(vals: np.ndarray, + tags: np.ndarray, + locs: np.ndarray, + keypoint_order: List[int], + val_thr: float, + tag_thr: float = 1.0, + max_groups: Optional[int] = None) -> np.ndarray: + """Group the keypoints by tags using Munkres algorithm. + + Note: + + - keypoint number: K + - candidate number: M + - tag dimenssion: L + - coordinate dimension: D + - group number: G + + Args: + vals (np.ndarray): The heatmap response values of keypoints in shape + (K, M) + tags (np.ndarray): The tags of the keypoint candidates in shape + (K, M, L) + locs (np.ndarray): The locations of the keypoint candidates in shape + (K, M, D) + keypoint_order (List[int]): The grouping order of the keypoints. + The groupping usually starts from a keypoints around the head and + torso, and gruadually moves out to the limbs + val_thr (float): The threshold of the keypoint response value + tag_thr (float): The maximum allowed tag distance when matching a + keypoint to a group. A keypoint with larger tag distance to any + of the existing groups will initializes a new group + max_groups (int, optional): The maximum group number. ``None`` means + no limitation. Defaults to ``None`` + + Returns: + np.ndarray: grouped keypoints in shape (G, K, D+1), where the last + dimenssion is the concatenated keypoint coordinates and scores. + """ + K, M, D = locs.shape + assert vals.shape == tags.shape[:2] == (K, M) + assert len(keypoint_order) == K + + # Build Munkres instance + munkres = Munkres() + + # Build a group pool, each group contains the keypoints of an instance + groups = [] + + Group = namedtuple('Group', field_names=['kpts', 'scores', 'tag_list']) + + def _init_group(): + """Initialize a group, which is composed of the keypoints, keypoint + scores and the tag of each keypoint.""" + _group = Group( + kpts=np.zeros((K, D), dtype=np.float32), + scores=np.zeros(K, dtype=np.float32), + tag_list=[]) + return _group + + for i in keypoint_order: + # Get all valid candidate of the i-th keypoints + valid = vals[i] > val_thr + if not valid.any(): + continue + + tags_i = tags[i, valid] # (M', L) + vals_i = vals[i, valid] # (M',) + locs_i = locs[i, valid] # (M', D) + + if len(groups) == 0: # Initialize the group pool + for tag, val, loc in zip(tags_i, vals_i, locs_i): + group = _init_group() + group.kpts[i] = loc + group.scores[i] = val + group.tag_list.append(tag) + + groups.append(group) + + else: # Match keypoints to existing groups + groups = groups[:max_groups] + group_tags = [np.mean(g.tag_list, axis=0) for g in groups] + + # Calculate distance matrix between group tags and tag candidates + # of the i-th keypoint + # Shape: (M', 1, L) , (1, G, L) -> (M', G, L) + diff = tags_i[:, None] - np.array(group_tags)[None] + dists = np.linalg.norm(diff, ord=2, axis=2) + num_kpts, num_groups = dists.shape[:2] + + # Experimental cost function for keypoint-group matching + costs = np.round(dists) * 100 - vals_i[..., None] + if num_kpts > num_groups: + padding = np.full((num_kpts, num_kpts - num_groups), + 1e10, + dtype=np.float32) + costs = np.concatenate((costs, padding), axis=1) + + # Match keypoints and groups by Munkres algorithm + matches = munkres.compute(costs) + for kpt_idx, group_idx in matches: + if group_idx < num_groups and dists[kpt_idx, + group_idx] < tag_thr: + # Add the keypoint to the matched group + group = groups[group_idx] + else: + # Initialize a new group with unmatched keypoint + group = _init_group() + groups.append(group) + + group.kpts[i] = locs_i[kpt_idx] + group.scores[i] = vals_i[kpt_idx] + group.tag_list.append(tags_i[kpt_idx]) + + groups = groups[:max_groups] + if groups: + grouped_keypoints = np.stack( + [np.r_['1', g.kpts, g.scores[:, None]] for g in groups]) + else: + grouped_keypoints = np.empty((0, K, D + 1)) + + return grouped_keypoints + + +@KEYPOINT_CODECS.register_module() +class AssociativeEmbedding(BaseKeypointCodec): + """Encode/decode keypoints with the method introduced in "Associative + Embedding". This is an asymmetric codec, where the keypoints are + represented as gaussian heatmaps and position indices during encoding, and + restored from predicted heatmaps and group tags. + + See the paper `Associative Embedding: End-to-End Learning for Joint + Detection and Grouping`_ by Newell et al (2017) for details + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - embedding tag dimension: L + - image size: [w, h] + - heatmap size: [W, H] + + Encoded: + + - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) + where [W, H] is the `heatmap_size` + - keypoint_indices (np.ndarray): The keypoint position indices in shape + (N, K, 2). Each keypoint's index is [i, v], where i is the position + index in the heatmap (:math:`i=y*w+x`) and v is the visibility + - keypoint_weights (np.ndarray): The target weights in shape (N, K) + + Args: + input_size (tuple): Image size in [w, h] + heatmap_size (tuple): Heatmap size in [W, H] + sigma (float): The sigma value of the Gaussian heatmap + use_udp (bool): Whether use unbiased data processing. See + `UDP (CVPR 2020)`_ for details. Defaults to ``False`` + decode_keypoint_order (List[int]): The grouping order of the + keypoint indices. The groupping usually starts from a keypoints + around the head and torso, and gruadually moves out to the limbs + decode_keypoint_thr (float): The threshold of keypoint response value + in heatmaps. Defaults to 0.1 + decode_tag_thr (float): The maximum allowed tag distance when matching + a keypoint to a group. A keypoint with larger tag distance to any + of the existing groups will initializes a new group. Defaults to + 1.0 + decode_nms_kernel (int): The kernel size of the NMS during decoding, + which should be an odd integer. Defaults to 5 + decode_gaussian_kernel (int): The kernel size of the Gaussian blur + during decoding, which should be an odd integer. It is only used + when ``self.use_udp==True``. Defaults to 3 + decode_topk (int): The number top-k candidates of each keypoints that + will be retrieved from the heatmaps during dedocding. Defaults to + 20 + decode_max_instances (int, optional): The maximum number of instances + to decode. ``None`` means no limitation to the instance number. + Defaults to ``None`` + + .. _`Associative Embedding: End-to-End Learning for Joint Detection and + Grouping`: https://arxiv.org/abs/1611.05424 + .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524 + """ + + def __init__( + self, + input_size: Tuple[int, int], + heatmap_size: Tuple[int, int], + sigma: Optional[float] = None, + use_udp: bool = False, + decode_keypoint_order: List[int] = [], + decode_nms_kernel: int = 5, + decode_gaussian_kernel: int = 3, + decode_keypoint_thr: float = 0.1, + decode_tag_thr: float = 1.0, + decode_topk: int = 20, + decode_max_instances: Optional[int] = None, + ) -> None: + super().__init__() + self.input_size = input_size + self.heatmap_size = heatmap_size + self.use_udp = use_udp + self.decode_nms_kernel = decode_nms_kernel + self.decode_gaussian_kernel = decode_gaussian_kernel + self.decode_keypoint_thr = decode_keypoint_thr + self.decode_tag_thr = decode_tag_thr + self.decode_topk = decode_topk + self.decode_max_instances = decode_max_instances + self.dedecode_keypoint_order = decode_keypoint_order.copy() + + if self.use_udp: + self.scale_factor = ((np.array(input_size) - 1) / + (np.array(heatmap_size) - 1)).astype( + np.float32) + else: + self.scale_factor = (np.array(input_size) / + heatmap_size).astype(np.float32) + + if sigma is None: + sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 64 + self.sigma = sigma + + def encode( + self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Encode keypoints into heatmaps and position indices. Note that the + original keypoint coordinates should be in the input image space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + dict: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_indices (np.ndarray): The keypoint position indices + in shape (N, K, 2). Each keypoint's index is [i, v], where i + is the position index in the heatmap (:math:`i=y*w+x`) and v + is the visibility + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + """ + + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + # keypoint coordinates in heatmap + _keypoints = keypoints / self.scale_factor + + if self.use_udp: + heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=_keypoints, + keypoints_visible=keypoints_visible, + sigma=self.sigma) + else: + heatmaps, keypoint_weights = generate_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=_keypoints, + keypoints_visible=keypoints_visible, + sigma=self.sigma) + + keypoint_indices = self._encode_keypoint_indices( + heatmap_size=self.heatmap_size, + keypoints=_keypoints, + keypoints_visible=keypoints_visible) + + encoded = dict( + heatmaps=heatmaps, + keypoint_indices=keypoint_indices, + keypoint_weights=keypoint_weights) + + return encoded + + def _encode_keypoint_indices(self, heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray) -> np.ndarray: + w, h = heatmap_size + N, K, _ = keypoints.shape + keypoint_indices = np.zeros((N, K, 2), dtype=np.int64) + + for n, k in product(range(N), range(K)): + x, y = (keypoints[n, k] + 0.5).astype(np.int64) + index = y * w + x + vis = (keypoints_visible[n, k] > 0.5 and 0 <= x < w and 0 <= y < h) + keypoint_indices[n, k] = [index, vis] + + return keypoint_indices + + def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]: + raise NotImplementedError() + + def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor, + k: int): + """Get top-k response values from the heatmaps and corresponding tag + values from the tagging heatmaps. + + Args: + batch_heatmaps (Tensor): Keypoint detection heatmaps in shape + (B, K, H, W) + batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where + the tag dim C is 2*K when using flip testing, or K otherwise + k (int): The number of top responses to get + + Returns: + tuple: + - topk_vals (Tensor): Top-k response values of each heatmap in + shape (B, K, Topk) + - topk_tags (Tensor): The corresponding embedding tags of the + top-k responses, in shape (B, K, Topk, L) + - topk_locs (Tensor): The location of the top-k responses in each + heatmap, in shape (B, K, Topk, 2) where last dimension + represents x and y coordinates + """ + B, K, H, W = batch_heatmaps.shape + L = batch_tags.shape[1] // K + + # shape of topk_val, top_indices: (B, K, TopK) + topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk( + k, dim=-1) + + topk_tags_per_kpts = [ + torch.gather(_tag, dim=2, index=topk_indices) + for _tag in torch.unbind(batch_tags.view(B, L, K, H * W), dim=1) + ] + + topk_tags = torch.stack(topk_tags_per_kpts, dim=-1) # (B, K, TopK, L) + topk_locs = torch.stack([topk_indices % W, topk_indices // W], + dim=-1) # (B, K, TopK, 2) + + return topk_vals, topk_tags, topk_locs + + def _group_keypoints(self, batch_vals: np.ndarray, batch_tags: np.ndarray, + batch_locs: np.ndarray): + """Group keypoints into groups (each represents an instance) by tags. + + Args: + batch_vals (Tensor): Heatmap response values of keypoint + candidates in shape (B, K, Topk) + batch_tags (Tensor): Tags of keypoint candidates in shape + (B, K, Topk, L) + batch_locs (Tensor): Locations of keypoint candidates in shape + (B, K, Topk, 2) + + Returns: + List[np.ndarray]: Grouping results of a batch, each element is a + np.ndarray (in shape [N, K, D+1]) that contains the groups + detected in an image, including both keypoint coordinates and + scores. + """ + + def _group_func(inputs: Tuple): + vals, tags, locs = inputs + return _group_keypoints_by_tags( + vals, + tags, + locs, + keypoint_order=self.dedecode_keypoint_order, + val_thr=self.decode_keypoint_thr, + tag_thr=self.decode_tag_thr, + max_groups=self.decode_max_instances) + + _results = map(_group_func, zip(batch_vals, batch_tags, batch_locs)) + results = list(_results) + return results + + def _fill_missing_keypoints(self, keypoints: np.ndarray, + keypoint_scores: np.ndarray, + heatmaps: np.ndarray, tags: np.ndarray): + """Fill the missing keypoints in the initial predictions. + + Args: + keypoints (np.ndarray): Keypoint predictions in shape (N, K, D) + keypoint_scores (np.ndarray): Keypint score predictions in shape + (N, K), in which 0 means the corresponding keypoint is + missing in the initial prediction + heatmaps (np.ndarry): Heatmaps in shape (K, H, W) + tags (np.ndarray): Tagging heatmaps in shape (C, H, W) where + C=L*K + + Returns: + tuple: + - keypoints (np.ndarray): Keypoint predictions with missing + ones filled + - keypoint_scores (np.ndarray): Keypoint score predictions with + missing ones filled + """ + + N, K = keypoints.shape[:2] + H, W = heatmaps.shape[1:] + L = tags.shape[0] // K + keypoint_tags = [tags[k::K] for k in range(K)] + + for n in range(N): + # Calculate the instance tag (mean tag of detected keypoints) + _tag = [] + for k in range(K): + if keypoint_scores[n, k] > 0: + x, y = keypoints[n, k, :2].astype(np.int64) + x = np.clip(x, 0, W - 1) + y = np.clip(y, 0, H - 1) + _tag.append(keypoint_tags[k][:, y, x]) + + tag = np.mean(_tag, axis=0) + tag = tag.reshape(L, 1, 1) + # Search maximum response of the missing keypoints + for k in range(K): + if keypoint_scores[n, k] > 0: + continue + dist_map = np.linalg.norm( + keypoint_tags[k] - tag, ord=2, axis=0) + cost_map = np.round(dist_map) * 100 - heatmaps[k] # H, W + y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W)) + keypoints[n, k] = [x, y] + keypoint_scores[n, k] = heatmaps[k, y, x] + + return keypoints, keypoint_scores + + def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """Decode the keypoint coordinates from a batch of heatmaps and tagging + heatmaps. The decoded keypoint coordinates are in the input image + space. + + Args: + batch_heatmaps (Tensor): Keypoint detection heatmaps in shape + (B, K, H, W) + batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where + :math:`C=L*K` + + Returns: + tuple: + - batch_keypoints (List[np.ndarray]): Decoded keypoint coordinates + of the batch, each is in shape (N, K, D) + - batch_scores (List[np.ndarray]): Decoded keypoint scores of the + batch, each is in shape (N, K). It usually represents the + confidience of the keypoint prediction + """ + B, _, H, W = batch_heatmaps.shape + assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), ( + f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and ' + f'tagging map ({batch_tags.shape})') + + # Heatmap NMS + batch_heatmaps = batch_heatmap_nms(batch_heatmaps, + self.decode_nms_kernel) + + # Get top-k in each heatmap and and convert to numpy + batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy( + self._get_batch_topk( + batch_heatmaps, batch_tags, k=self.decode_topk)) + + # Group keypoint candidates into groups (instances) + batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags, + batch_topk_locs) + + # Convert to numpy + batch_heatmaps_np = to_numpy(batch_heatmaps) + batch_tags_np = to_numpy(batch_tags) + + # Refine the keypoint prediction + batch_keypoints = [] + batch_keypoint_scores = [] + for i, (groups, heatmaps, tags) in enumerate( + zip(batch_groups, batch_heatmaps_np, batch_tags_np)): + + keypoints, scores = groups[..., :-1], groups[..., -1] + + if keypoints.size > 0: + # identify missing keypoints + keypoints, scores = self._fill_missing_keypoints( + keypoints, scores, heatmaps, tags) + + # refine keypoint coordinates according to heatmap distribution + if self.use_udp: + keypoints = refine_keypoints_dark_udp( + keypoints, + heatmaps, + blur_kernel_size=self.decode_gaussian_kernel) + else: + keypoints = refine_keypoints(keypoints, heatmaps) + + batch_keypoints.append(keypoints) + batch_keypoint_scores.append(scores) + + # restore keypoint scale + batch_keypoints = [ + kpts * self.scale_factor for kpts in batch_keypoints + ] + + return batch_keypoints, batch_keypoint_scores diff --git a/mmpose/codecs/base.py b/mmpose/codecs/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d8479fdf1e1c5d51c0c0a2722c54b8a6c018c113 --- /dev/null +++ b/mmpose/codecs/base.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Any, List, Optional, Tuple + +import numpy as np +from mmengine.utils import is_method_overridden + + +class BaseKeypointCodec(metaclass=ABCMeta): + """The base class of the keypoint codec. + + A keypoint codec is a module to encode keypoint coordinates to specific + representation (e.g. heatmap) and vice versa. A subclass should implement + the methods :meth:`encode` and :meth:`decode`. + """ + + # pass additional encoding arguments to the `encode` method, beyond the + # mandatory `keypoints` and `keypoints_visible` arguments. + auxiliary_encode_keys = set() + + @abstractmethod + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + """Encode keypoints. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibility in shape + (N, K, D) + + Returns: + dict: Encoded items. + """ + + @abstractmethod + def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoints. + + Args: + encoded (any): Encoded keypoint representation using the codec + + Returns: + tuple: + - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + - keypoints_visible (np.ndarray): Keypoint visibility in shape + (N, K, D) + """ + + def batch_decode(self, batch_encoded: Any + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """Decode keypoints. + + Args: + batch_encoded (any): A batch of encoded keypoint + representations + + Returns: + tuple: + - batch_keypoints (List[np.ndarray]): Each element is keypoint + coordinates in shape (N, K, D) + - batch_keypoints (List[np.ndarray]): Each element is keypoint + visibility in shape (N, K) + """ + raise NotImplementedError() + + @property + def support_batch_decoding(self) -> bool: + """Return whether the codec support decoding from batch data.""" + return is_method_overridden('batch_decode', BaseKeypointCodec, + self.__class__) diff --git a/mmpose/codecs/decoupled_heatmap.py b/mmpose/codecs/decoupled_heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..da38a4ce2c825f5f362b19ee09d2515720c1a1f3 --- /dev/null +++ b/mmpose/codecs/decoupled_heatmap.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec +from .utils import (generate_gaussian_heatmaps, get_diagonal_lengths, + get_instance_bbox, get_instance_root) +from .utils.post_processing import get_heatmap_maximum +from .utils.refinement import refine_keypoints + + +@KEYPOINT_CODECS.register_module() +class DecoupledHeatmap(BaseKeypointCodec): + """Encode/decode keypoints with the method introduced in the paper CID. + + See the paper Contextual Instance Decoupling for Robust Multi-Person + Pose Estimation`_ by Wang et al (2022) for details + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + - heatmap size: [W, H] + + Encoded: + - heatmaps (np.ndarray): The coupled heatmap in shape + (1+K, H, W) where [W, H] is the `heatmap_size`. + - instance_heatmaps (np.ndarray): The decoupled heatmap in shape + (M*K, H, W) where M is the number of instances. + - keypoint_weights (np.ndarray): The weight for heatmaps in shape + (M*K). + - instance_coords (np.ndarray): The coordinates of instance roots + in shape (M, 2) + + Args: + input_size (tuple): Image size in [w, h] + heatmap_size (tuple): Heatmap size in [W, H] + root_type (str): The method to generate the instance root. Options + are: + + - ``'kpt_center'``: Average coordinate of all visible keypoints. + - ``'bbox_center'``: Center point of bounding boxes outlined by + all visible keypoints. + + Defaults to ``'kpt_center'`` + + heatmap_min_overlap (float): Minimum overlap rate among instances. + Used when calculating sigmas for instances. Defaults to 0.7 + background_weight (float): Loss weight of background pixels. + Defaults to 0.1 + encode_max_instances (int): The maximum number of instances + to encode for each sample. Defaults to 30 + + .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_ + Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_ + CVPR_2022_paper.html + """ + + # DecoupledHeatmap requires bounding boxes to determine the size of each + # instance, so that it can assign varying sigmas based on their size + auxiliary_encode_keys = {'bbox'} + + def __init__( + self, + input_size: Tuple[int, int], + heatmap_size: Tuple[int, int], + root_type: str = 'kpt_center', + heatmap_min_overlap: float = 0.7, + encode_max_instances: int = 30, + ): + super().__init__() + + self.input_size = input_size + self.heatmap_size = heatmap_size + self.root_type = root_type + self.encode_max_instances = encode_max_instances + self.heatmap_min_overlap = heatmap_min_overlap + + self.scale_factor = (np.array(input_size) / + heatmap_size).astype(np.float32) + + def _get_instance_wise_sigmas( + self, + bbox: np.ndarray, + ) -> np.ndarray: + """Get sigma values for each instance according to their size. + + Args: + bbox (np.ndarray): Bounding box in shape (N, 4, 2) + + Returns: + np.ndarray: Array containing the sigma values for each instance. + """ + sigmas = np.zeros((bbox.shape[0], ), dtype=np.float32) + + heights = np.sqrt(np.power(bbox[:, 0] - bbox[:, 1], 2).sum(axis=-1)) + widths = np.sqrt(np.power(bbox[:, 0] - bbox[:, 2], 2).sum(axis=-1)) + + for i in range(bbox.shape[0]): + h, w = heights[i], widths[i] + + # compute sigma for each instance + # condition 1 + a1, b1 = 1, h + w + c1 = w * h * (1 - self.heatmap_min_overlap) / ( + 1 + self.heatmap_min_overlap) + sq1 = np.sqrt(b1**2 - 4 * a1 * c1) + r1 = (b1 + sq1) / 2 + + # condition 2 + a2 = 4 + b2 = 2 * (h + w) + c2 = (1 - self.heatmap_min_overlap) * w * h + sq2 = np.sqrt(b2**2 - 4 * a2 * c2) + r2 = (b2 + sq2) / 2 + + # condition 3 + a3 = 4 * self.heatmap_min_overlap + b3 = -2 * self.heatmap_min_overlap * (h + w) + c3 = (self.heatmap_min_overlap - 1) * w * h + sq3 = np.sqrt(b3**2 - 4 * a3 * c3) + r3 = (b3 + sq3) / 2 + + sigmas[i] = min(r1, r2, r3) / 3 + + return sigmas + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + bbox: Optional[np.ndarray] = None) -> dict: + """Encode keypoints into heatmaps. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + bbox (np.ndarray): Bounding box in shape (N, 8) which includes + coordinates of 4 corners. + + Returns: + dict: + - heatmaps (np.ndarray): The coupled heatmap in shape + (1+K, H, W) where [W, H] is the `heatmap_size`. + - instance_heatmaps (np.ndarray): The decoupled heatmap in shape + (N*K, H, W) where M is the number of instances. + - keypoint_weights (np.ndarray): The weight for heatmaps in shape + (N*K). + - instance_coords (np.ndarray): The coordinates of instance roots + in shape (N, 2) + """ + + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + if bbox is None: + # generate pseudo bbox via visible keypoints + bbox = get_instance_bbox(keypoints, keypoints_visible) + bbox = np.tile(bbox, 2).reshape(-1, 4, 2) + # corner order: left_top, left_bottom, right_top, right_bottom + bbox[:, 1:3, 0] = bbox[:, 0:2, 0] + + # keypoint coordinates in heatmap + _keypoints = keypoints / self.scale_factor + _bbox = bbox.reshape(-1, 4, 2) / self.scale_factor + + # compute the root and scale of each instance + roots, roots_visible = get_instance_root(_keypoints, keypoints_visible, + self.root_type) + + sigmas = self._get_instance_wise_sigmas(_bbox) + + # generate global heatmaps + heatmaps, keypoint_weights = generate_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=np.concatenate((_keypoints, roots[:, None]), axis=1), + keypoints_visible=np.concatenate( + (keypoints_visible, roots_visible[:, None]), axis=1), + sigma=sigmas) + roots_visible = keypoint_weights[:, -1] + + # select instances + inst_roots, inst_indices = [], [] + diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible) + for i in np.argsort(diagonal_lengths): + if roots_visible[i] < 1: + continue + # rand root point in 3x3 grid + x, y = roots[i] + np.random.randint(-1, 2, (2, )) + x = max(0, min(x, self.heatmap_size[0] - 1)) + y = max(0, min(y, self.heatmap_size[1] - 1)) + if (x, y) not in inst_roots: + inst_roots.append((x, y)) + inst_indices.append(i) + if len(inst_indices) > self.encode_max_instances: + rand_indices = random.sample( + range(len(inst_indices)), self.encode_max_instances) + inst_roots = [inst_roots[i] for i in rand_indices] + inst_indices = [inst_indices[i] for i in rand_indices] + + # generate instance-wise heatmaps + inst_heatmaps, inst_heatmap_weights = [], [] + for i in inst_indices: + inst_heatmap, inst_heatmap_weight = generate_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=_keypoints[i:i + 1], + keypoints_visible=keypoints_visible[i:i + 1], + sigma=sigmas[i].item()) + inst_heatmaps.append(inst_heatmap) + inst_heatmap_weights.append(inst_heatmap_weight) + + if len(inst_indices) > 0: + inst_heatmaps = np.concatenate(inst_heatmaps) + inst_heatmap_weights = np.concatenate(inst_heatmap_weights) + inst_roots = np.array(inst_roots, dtype=np.int32) + else: + inst_heatmaps = np.empty((0, *self.heatmap_size[::-1])) + inst_heatmap_weights = np.empty((0, )) + inst_roots = np.empty((0, 2), dtype=np.int32) + + encoded = dict( + heatmaps=heatmaps, + instance_heatmaps=inst_heatmaps, + keypoint_weights=inst_heatmap_weights, + instance_coords=inst_roots) + + return encoded + + def decode(self, instance_heatmaps: np.ndarray, + instance_scores: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from decoupled heatmaps. The decoded + keypoint coordinates are in the input image space. + + Args: + instance_heatmaps (np.ndarray): Heatmaps in shape (N, K, H, W) + instance_scores (np.ndarray): Confidence of instance roots + prediction in shape (N, 1) + + Returns: + tuple: + - keypoints (np.ndarray): Decoded keypoint coordinates in shape + (N, K, D) + - scores (np.ndarray): The keypoint scores in shape (N, K). It + usually represents the confidence of the keypoint prediction + """ + keypoints, keypoint_scores = [], [] + + for i in range(instance_heatmaps.shape[0]): + heatmaps = instance_heatmaps[i].copy() + kpts, scores = get_heatmap_maximum(heatmaps) + keypoints.append(refine_keypoints(kpts[None], heatmaps)) + keypoint_scores.append(scores[None]) + + keypoints = np.concatenate(keypoints) + # Restore the keypoint scale + keypoints = keypoints * self.scale_factor + + keypoint_scores = np.concatenate(keypoint_scores) + keypoint_scores *= instance_scores + + return keypoints, keypoint_scores diff --git a/mmpose/codecs/integral_regression_label.py b/mmpose/codecs/integral_regression_label.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8e72cb104bde62353d5e5c08fd40a2b8e635f6 --- /dev/null +++ b/mmpose/codecs/integral_regression_label.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec +from .msra_heatmap import MSRAHeatmap +from .regression_label import RegressionLabel + + +@KEYPOINT_CODECS.register_module() +class IntegralRegressionLabel(BaseKeypointCodec): + """Generate keypoint coordinates and normalized heatmaps. See the paper: + `DSNT`_ by Nibali et al(2018). + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + + Encoded: + + - keypoint_labels (np.ndarray): The normalized regression labels in + shape (N, K, D) where D is 2 for 2d coordinates + - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) where + [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape (N, K) + + Args: + input_size (tuple): Input image size in [w, h] + heatmap_size (tuple): Heatmap size in [W, H] + sigma (float): The sigma value of the Gaussian heatmap + unbiased (bool): Whether use unbiased method (DarkPose) in ``'msra'`` + encoding. See `Dark Pose`_ for details. Defaults to ``False`` + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation in DarkPose. The kernel size and sigma should follow + the expirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`. + Defaults to 11 + normalize (bool): Whether to normalize the heatmaps. Defaults to True. + + .. _`DSNT`: https://arxiv.org/abs/1801.07372 + """ + + def __init__(self, + input_size: Tuple[int, int], + heatmap_size: Tuple[int, int], + sigma: float, + unbiased: bool = False, + blur_kernel_size: int = 11, + normalize: bool = True) -> None: + super().__init__() + + self.heatmap_codec = MSRAHeatmap(input_size, heatmap_size, sigma, + unbiased, blur_kernel_size) + self.keypoint_codec = RegressionLabel(input_size) + self.normalize = normalize + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + """Encoding keypoints to regression labels and heatmaps. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + dict: + - keypoint_labels (np.ndarray): The normalized regression labels in + shape (N, K, D) where D is 2 for 2d coordinates + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + """ + encoded_hm = self.heatmap_codec.encode(keypoints, keypoints_visible) + encoded_kp = self.keypoint_codec.encode(keypoints, keypoints_visible) + + heatmaps = encoded_hm['heatmaps'] + keypoint_labels = encoded_kp['keypoint_labels'] + keypoint_weights = encoded_kp['keypoint_weights'] + + if self.normalize: + val_sum = heatmaps.sum(axis=(-1, -2)).reshape(-1, 1, 1) + 1e-24 + heatmaps = heatmaps / val_sum + + encoded = dict( + keypoint_labels=keypoint_labels, + heatmaps=heatmaps, + keypoint_weights=keypoint_weights) + + return encoded + + def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from normalized space to input image + space. + + Args: + encoded (np.ndarray): Coordinates in shape (N, K, D) + + Returns: + tuple: + - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D) + - socres (np.ndarray): The keypoint scores in shape (N, K). + It usually represents the confidence of the keypoint prediction + """ + + keypoints, scores = self.keypoint_codec.decode(encoded) + + return keypoints, scores diff --git a/mmpose/codecs/megvii_heatmap.py b/mmpose/codecs/megvii_heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..e898004637b6d804cb256de369a64ed4d41e560d --- /dev/null +++ b/mmpose/codecs/megvii_heatmap.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product +from typing import Optional, Tuple + +import cv2 +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec +from .utils import gaussian_blur, get_heatmap_maximum + + +@KEYPOINT_CODECS.register_module() +class MegviiHeatmap(BaseKeypointCodec): + """Represent keypoints as heatmaps via "Megvii" approach. See `MSPN`_ + (2019) and `CPN`_ (2018) for details. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + - heatmap size: [W, H] + + Encoded: + + - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) + where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape (N, K) + + Args: + input_size (tuple): Image size in [w, h] + heatmap_size (tuple): Heatmap size in [W, H] + kernel_size (tuple): The kernel size of the heatmap gaussian in + [ks_x, ks_y] + + .. _`MSPN`: https://arxiv.org/abs/1901.00148 + .. _`CPN`: https://arxiv.org/abs/1711.07319 + """ + + def __init__( + self, + input_size: Tuple[int, int], + heatmap_size: Tuple[int, int], + kernel_size: int, + ) -> None: + + super().__init__() + self.input_size = input_size + self.heatmap_size = heatmap_size + self.kernel_size = kernel_size + self.scale_factor = (np.array(input_size) / + heatmap_size).astype(np.float32) + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + """Encode keypoints into heatmaps. Note that the original keypoint + coordinates should be in the input image space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + dict: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + """ + + N, K, _ = keypoints.shape + W, H = self.heatmap_size + + assert N == 1, ( + f'{self.__class__.__name__} only support single-instance ' + 'keypoint encoding') + + heatmaps = np.zeros((K, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + # get center coordinates + kx, ky = (keypoints[n, k] / self.scale_factor).astype(np.int64) + if kx < 0 or kx >= W or ky < 0 or ky >= H: + keypoint_weights[n, k] = 0 + continue + + heatmaps[k, ky, kx] = 1. + kernel_size = (self.kernel_size, self.kernel_size) + heatmaps[k] = cv2.GaussianBlur(heatmaps[k], kernel_size, 0) + + # normalize the heatmap + heatmaps[k] = heatmaps[k] / heatmaps[k, ky, kx] * 255. + + encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights) + + return encoded + + def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from heatmaps. The decoded keypoint + coordinates are in the input image space. + + Args: + encoded (np.ndarray): Heatmaps in shape (K, H, W) + + Returns: + tuple: + - keypoints (np.ndarray): Decoded keypoint coordinates in shape + (K, D) + - scores (np.ndarray): The keypoint scores in shape (K,). It + usually represents the confidence of the keypoint prediction + """ + heatmaps = gaussian_blur(encoded.copy(), self.kernel_size) + K, H, W = heatmaps.shape + + keypoints, scores = get_heatmap_maximum(heatmaps) + + for k in range(K): + heatmap = heatmaps[k] + px = int(keypoints[k, 0]) + py = int(keypoints[k, 1]) + if 1 < px < W - 1 and 1 < py < H - 1: + diff = np.array([ + heatmap[py][px + 1] - heatmap[py][px - 1], + heatmap[py + 1][px] - heatmap[py - 1][px] + ]) + keypoints[k] += (np.sign(diff) * 0.25 + 0.5) + + scores = scores / 255.0 + 0.5 + + # Unsqueeze the instance dimension for single-instance results + # and restore the keypoint scales + keypoints = keypoints[None] * self.scale_factor + scores = scores[None] + + return keypoints, scores diff --git a/mmpose/codecs/msra_heatmap.py b/mmpose/codecs/msra_heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..63ba292e4de213d3151c6b0668257f1e77f3d195 --- /dev/null +++ b/mmpose/codecs/msra_heatmap.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec +from .utils.gaussian_heatmap import (generate_gaussian_heatmaps, + generate_unbiased_gaussian_heatmaps) +from .utils.post_processing import get_heatmap_maximum +from .utils.refinement import refine_keypoints, refine_keypoints_dark + + +@KEYPOINT_CODECS.register_module() +class MSRAHeatmap(BaseKeypointCodec): + """Represent keypoints as heatmaps via "MSRA" approach. See the paper: + `Simple Baselines for Human Pose Estimation and Tracking`_ by Xiao et al + (2018) for details. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + - heatmap size: [W, H] + + Encoded: + + - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) + where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape (N, K) + + Args: + input_size (tuple): Image size in [w, h] + heatmap_size (tuple): Heatmap size in [W, H] + sigma (float): The sigma value of the Gaussian heatmap + unbiased (bool): Whether use unbiased method (DarkPose) in ``'msra'`` + encoding. See `Dark Pose`_ for details. Defaults to ``False`` + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation in DarkPose. The kernel size and sigma should follow + the expirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`. + Defaults to 11 + + .. _`Simple Baselines for Human Pose Estimation and Tracking`: + https://arxiv.org/abs/1804.06208 + .. _`Dark Pose`: https://arxiv.org/abs/1910.06278 + """ + + def __init__(self, + input_size: Tuple[int, int], + heatmap_size: Tuple[int, int], + sigma: float, + unbiased: bool = False, + blur_kernel_size: int = 11) -> None: + super().__init__() + self.input_size = input_size + self.heatmap_size = heatmap_size + self.sigma = sigma + self.unbiased = unbiased + + # The Gaussian blur kernel size of the heatmap modulation + # in DarkPose and the sigma value follows the expirical + # formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8` + # which gives: + # sigma~=3 if ks=17 + # sigma=2 if ks=11; + # sigma~=1.5 if ks=7; + # sigma~=1 if ks=3; + self.blur_kernel_size = blur_kernel_size + self.scale_factor = (np.array(input_size) / + heatmap_size).astype(np.float32) + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + """Encode keypoints into heatmaps. Note that the original keypoint + coordinates should be in the input image space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + dict: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + """ + + assert keypoints.shape[0] == 1, ( + f'{self.__class__.__name__} only support single-instance ' + 'keypoint encoding') + + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + if self.unbiased: + heatmaps, keypoint_weights = generate_unbiased_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=keypoints / self.scale_factor, + keypoints_visible=keypoints_visible, + sigma=self.sigma) + else: + heatmaps, keypoint_weights = generate_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=keypoints / self.scale_factor, + keypoints_visible=keypoints_visible, + sigma=self.sigma) + + encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights) + + return encoded + + def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from heatmaps. The decoded keypoint + coordinates are in the input image space. + + Args: + encoded (np.ndarray): Heatmaps in shape (K, H, W) + + Returns: + tuple: + - keypoints (np.ndarray): Decoded keypoint coordinates in shape + (N, K, D) + - scores (np.ndarray): The keypoint scores in shape (N, K). It + usually represents the confidence of the keypoint prediction + """ + heatmaps = encoded.copy() + K, H, W = heatmaps.shape + + keypoints, scores = get_heatmap_maximum(heatmaps) + + # Unsqueeze the instance dimension for single-instance results + keypoints, scores = keypoints[None], scores[None] + + if self.unbiased: + # Alleviate biased coordinate + keypoints = refine_keypoints_dark( + keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size) + + else: + keypoints = refine_keypoints(keypoints, heatmaps) + + # Restore the keypoint scale + keypoints = keypoints * self.scale_factor + + return keypoints, scores diff --git a/mmpose/codecs/regression_label.py b/mmpose/codecs/regression_label.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae385d2d97cebf3a7c097318b76a14626b0db89 --- /dev/null +++ b/mmpose/codecs/regression_label.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class RegressionLabel(BaseKeypointCodec): + r"""Generate keypoint coordinates. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + + Encoded: + + - keypoint_labels (np.ndarray): The normalized regression labels in + shape (N, K, D) where D is 2 for 2d coordinates + - keypoint_weights (np.ndarray): The target weights in shape (N, K) + + Args: + input_size (tuple): Input image size in [w, h] + + """ + + def __init__(self, input_size: Tuple[int, int]) -> None: + super().__init__() + + self.input_size = input_size + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + """Encoding keypoints from input image space to normalized space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + dict: + - keypoint_labels (np.ndarray): The normalized regression labels in + shape (N, K, D) where D is 2 for 2d coordinates + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + """ + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + w, h = self.input_size + valid = ((keypoints >= 0) & + (keypoints <= [w - 1, h - 1])).all(axis=-1) & ( + keypoints_visible > 0.5) + + keypoint_labels = (keypoints / np.array([w, h])).astype(np.float32) + keypoint_weights = np.where(valid, 1., 0.).astype(np.float32) + + encoded = dict( + keypoint_labels=keypoint_labels, keypoint_weights=keypoint_weights) + + return encoded + + def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from normalized space to input image + space. + + Args: + encoded (np.ndarray): Coordinates in shape (N, K, D) + + Returns: + tuple: + - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D) + - socres (np.ndarray): The keypoint scores in shape (N, K). + It usually represents the confidence of the keypoint prediction + """ + + if encoded.shape[-1] == 2: + N, K, _ = encoded.shape + normalized_coords = encoded.copy() + scores = np.ones((N, K), dtype=np.float32) + elif encoded.shape[-1] == 4: + # split coords and sigma if outputs contain output_sigma + normalized_coords = encoded[..., :2].copy() + output_sigma = encoded[..., 2:4].copy() + + scores = (1 - output_sigma).mean(axis=-1) + else: + raise ValueError( + 'Keypoint dimension should be 2 or 4 (with sigma), ' + f'but got {encoded.shape[-1]}') + + w, h = self.input_size + keypoints = normalized_coords * np.array([w, h]) + + return keypoints, scores diff --git a/mmpose/codecs/simcc_label.py b/mmpose/codecs/simcc_label.py new file mode 100644 index 0000000000000000000000000000000000000000..a22498c35291dfcd838ce1fe6451bd60dd0e6d48 --- /dev/null +++ b/mmpose/codecs/simcc_label.py @@ -0,0 +1,286 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product +from typing import Optional, Tuple, Union + +import numpy as np + +from mmpose.codecs.utils import get_simcc_maximum +from mmpose.codecs.utils.refinement import refine_simcc_dark +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class SimCCLabel(BaseKeypointCodec): + r"""Generate keypoint representation via "SimCC" approach. + See the paper: `SimCC: a Simple Coordinate Classification Perspective for + Human Pose Estimation`_ by Li et al (2022) for more details. + Old name: SimDR + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + + Encoded: + + - keypoint_x_labels (np.ndarray): The generated SimCC label for x-axis. + The label shape is (N, K, Wx) if ``smoothing_type=='gaussian'`` + and (N, K) if `smoothing_type=='standard'``, where + :math:`Wx=w*simcc_split_ratio` + - keypoint_y_labels (np.ndarray): The generated SimCC label for y-axis. + The label shape is (N, K, Wy) if ``smoothing_type=='gaussian'`` + and (N, K) if `smoothing_type=='standard'``, where + :math:`Wy=h*simcc_split_ratio` + - keypoint_weights (np.ndarray): The target weights in shape (N, K) + + Args: + input_size (tuple): Input image size in [w, h] + smoothing_type (str): The SimCC label smoothing strategy. Options are + ``'gaussian'`` and ``'standard'``. Defaults to ``'gaussian'`` + sigma (float | int | tuple): The sigma value in the Gaussian SimCC + label. Defaults to 6.0 + simcc_split_ratio (float): The ratio of the label size to the input + size. For example, if the input width is ``w``, the x label size + will be :math:`w*simcc_split_ratio`. Defaults to 2.0 + label_smooth_weight (float): Label Smoothing weight. Defaults to 0.0 + normalize (bool): Whether to normalize the heatmaps. Defaults to True. + + .. _`SimCC: a Simple Coordinate Classification Perspective for Human Pose + Estimation`: https://arxiv.org/abs/2107.03332 + """ + + def __init__(self, + input_size: Tuple[int, int], + smoothing_type: str = 'gaussian', + sigma: Union[float, int, Tuple[float]] = 6.0, + simcc_split_ratio: float = 2.0, + label_smooth_weight: float = 0.0, + normalize: bool = True, + use_dark: bool = False) -> None: + super().__init__() + + self.input_size = input_size + self.smoothing_type = smoothing_type + self.simcc_split_ratio = simcc_split_ratio + self.label_smooth_weight = label_smooth_weight + self.normalize = normalize + self.use_dark = use_dark + + if isinstance(sigma, (float, int)): + self.sigma = np.array([sigma, sigma]) + else: + self.sigma = np.array(sigma) + + if self.smoothing_type not in {'gaussian', 'standard'}: + raise ValueError( + f'{self.__class__.__name__} got invalid `smoothing_type` value' + f'{self.smoothing_type}. Should be one of ' + '{"gaussian", "standard"}') + + if self.smoothing_type == 'gaussian' and self.label_smooth_weight > 0: + raise ValueError('Attribute `label_smooth_weight` is only ' + 'used for `standard` mode.') + + if self.label_smooth_weight < 0.0 or self.label_smooth_weight > 1.0: + raise ValueError('`label_smooth_weight` should be in range [0, 1]') + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + """Encoding keypoints into SimCC labels. Note that the original + keypoint coordinates should be in the input image space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + dict: + - keypoint_x_labels (np.ndarray): The generated SimCC label for + x-axis. + The label shape is (N, K, Wx) if ``smoothing_type=='gaussian'`` + and (N, K) if `smoothing_type=='standard'``, where + :math:`Wx=w*simcc_split_ratio` + - keypoint_y_labels (np.ndarray): The generated SimCC label for + y-axis. + The label shape is (N, K, Wy) if ``smoothing_type=='gaussian'`` + and (N, K) if `smoothing_type=='standard'``, where + :math:`Wy=h*simcc_split_ratio` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + """ + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + if self.smoothing_type == 'gaussian': + x_labels, y_labels, keypoint_weights = self._generate_gaussian( + keypoints, keypoints_visible) + elif self.smoothing_type == 'standard': + x_labels, y_labels, keypoint_weights = self._generate_standard( + keypoints, keypoints_visible) + else: + raise ValueError( + f'{self.__class__.__name__} got invalid `smoothing_type` value' + f'{self.smoothing_type}. Should be one of ' + '{"gaussian", "standard"}') + + encoded = dict( + keypoint_x_labels=x_labels, + keypoint_y_labels=y_labels, + keypoint_weights=keypoint_weights) + + return encoded + + def decode(self, simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from SimCC representations. The decoded + coordinates are in the input image space. + + Args: + encoded (Tuple[np.ndarray, np.ndarray]): SimCC labels for x-axis + and y-axis + simcc_x (np.ndarray): SimCC label for x-axis + simcc_y (np.ndarray): SimCC label for y-axis + + Returns: + tuple: + - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D) + - socres (np.ndarray): The keypoint scores in shape (N, K). + It usually represents the confidence of the keypoint prediction + """ + + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + + # Unsqueeze the instance dimension for single-instance results + if keypoints.ndim == 2: + keypoints = keypoints[None, :] + scores = scores[None, :] + + if self.use_dark: + x_blur = int((self.sigma[0] * 20 - 7) // 3) + y_blur = int((self.sigma[1] * 20 - 7) // 3) + x_blur -= int((x_blur % 2) == 0) + y_blur -= int((y_blur % 2) == 0) + keypoints[:, :, 0] = refine_simcc_dark(keypoints[:, :, 0], simcc_x, + x_blur) + keypoints[:, :, 1] = refine_simcc_dark(keypoints[:, :, 1], simcc_y, + y_blur) + + keypoints /= self.simcc_split_ratio + + return keypoints, scores + + def _map_coordinates( + self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray]: + """Mapping keypoint coordinates into SimCC space.""" + + keypoints_split = keypoints.copy() + keypoints_split = np.around(keypoints_split * self.simcc_split_ratio) + keypoints_split = keypoints_split.astype(np.int64) + keypoint_weights = keypoints_visible.copy() + + return keypoints_split, keypoint_weights + + def _generate_standard( + self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Encoding keypoints into SimCC labels with Standard Label Smoothing + strategy. + + Labels will be one-hot vectors if self.label_smooth_weight==0.0 + """ + + N, K, _ = keypoints.shape + w, h = self.input_size + W = np.around(w * self.simcc_split_ratio).astype(int) + H = np.around(h * self.simcc_split_ratio).astype(int) + + keypoints_split, keypoint_weights = self._map_coordinates( + keypoints, keypoints_visible) + + target_x = np.zeros((N, K, W), dtype=np.float32) + target_y = np.zeros((N, K, H), dtype=np.float32) + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + # get center coordinates + mu_x, mu_y = keypoints_split[n, k].astype(np.int64) + + # detect abnormal coords and assign the weight 0 + if mu_x >= W or mu_y >= H or mu_x < 0 or mu_y < 0: + keypoint_weights[n, k] = 0 + continue + + if self.label_smooth_weight > 0: + target_x[n, k] = self.label_smooth_weight / (W - 1) + target_y[n, k] = self.label_smooth_weight / (H - 1) + + target_x[n, k, mu_x] = 1.0 - self.label_smooth_weight + target_y[n, k, mu_y] = 1.0 - self.label_smooth_weight + + return target_x, target_y, keypoint_weights + + def _generate_gaussian( + self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Encoding keypoints into SimCC labels with Gaussian Label Smoothing + strategy.""" + + N, K, _ = keypoints.shape + w, h = self.input_size + W = np.around(w * self.simcc_split_ratio).astype(int) + H = np.around(h * self.simcc_split_ratio).astype(int) + + keypoints_split, keypoint_weights = self._map_coordinates( + keypoints, keypoints_visible) + + target_x = np.zeros((N, K, W), dtype=np.float32) + target_y = np.zeros((N, K, H), dtype=np.float32) + + # 3-sigma rule + radius = self.sigma * 3 + + # xy grid + x = np.arange(0, W, 1, dtype=np.float32) + y = np.arange(0, H, 1, dtype=np.float32) + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + mu = keypoints_split[n, k] + + # check that the gaussian has in-bounds part + left, top = mu - radius + right, bottom = mu + radius + 1 + + if left >= W or top >= H or right < 0 or bottom < 0: + keypoint_weights[n, k] = 0 + continue + + mu_x, mu_y = mu + + target_x[n, k] = np.exp(-((x - mu_x)**2) / (2 * self.sigma[0]**2)) + target_y[n, k] = np.exp(-((y - mu_y)**2) / (2 * self.sigma[1]**2)) + + if self.normalize: + norm_value = self.sigma * np.sqrt(np.pi * 2) + target_x /= norm_value[0] + target_y /= norm_value[1] + + return target_x, target_y, keypoint_weights diff --git a/mmpose/codecs/spr.py b/mmpose/codecs/spr.py new file mode 100644 index 0000000000000000000000000000000000000000..add6f5715b52a42d58114bcb3d432ea38cfec29f --- /dev/null +++ b/mmpose/codecs/spr.py @@ -0,0 +1,299 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec +from .utils import (batch_heatmap_nms, generate_displacement_heatmap, + generate_gaussian_heatmaps, get_diagonal_lengths, + get_instance_root) + + +@KEYPOINT_CODECS.register_module() +class SPR(BaseKeypointCodec): + """Encode/decode keypoints with Structured Pose Representation (SPR). + + See the paper `Single-stage multi-person pose machines`_ + by Nie et al (2017) for details + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + - heatmap size: [W, H] + + Encoded: + + - heatmaps (np.ndarray): The generated heatmap in shape (1, H, W) + where [W, H] is the `heatmap_size`. If the keypoint heatmap is + generated together, the output heatmap shape is (K+1, H, W) + - heatmap_weights (np.ndarray): The target weights for heatmaps which + has same shape with heatmaps. + - displacements (np.ndarray): The dense keypoint displacement in + shape (K*2, H, W). + - displacement_weights (np.ndarray): The target weights for heatmaps + which has same shape with displacements. + + Args: + input_size (tuple): Image size in [w, h] + heatmap_size (tuple): Heatmap size in [W, H] + sigma (float or tuple, optional): The sigma values of the Gaussian + heatmaps. If sigma is a tuple, it includes both sigmas for root + and keypoint heatmaps. ``None`` means the sigmas are computed + automatically from the heatmap size. Defaults to ``None`` + generate_keypoint_heatmaps (bool): Whether to generate Gaussian + heatmaps for each keypoint. Defaults to ``False`` + root_type (str): The method to generate the instance root. Options + are: + + - ``'kpt_center'``: Average coordinate of all visible keypoints. + - ``'bbox_center'``: Center point of bounding boxes outlined by + all visible keypoints. + + Defaults to ``'kpt_center'`` + + minimal_diagonal_length (int or float): The threshold of diagonal + length of instance bounding box. Small instances will not be + used in training. Defaults to 32 + background_weight (float): Loss weight of background pixels. + Defaults to 0.1 + decode_thr (float): The threshold of keypoint response value in + heatmaps. Defaults to 0.01 + decode_nms_kernel (int): The kernel size of the NMS during decoding, + which should be an odd integer. Defaults to 5 + decode_max_instances (int): The maximum number of instances + to decode. Defaults to 30 + + .. _`Single-stage multi-person pose machines`: + https://arxiv.org/abs/1908.09220 + """ + + def __init__( + self, + input_size: Tuple[int, int], + heatmap_size: Tuple[int, int], + sigma: Optional[Union[float, Tuple[float]]] = None, + generate_keypoint_heatmaps: bool = False, + root_type: str = 'kpt_center', + minimal_diagonal_length: Union[int, float] = 5, + background_weight: float = 0.1, + decode_nms_kernel: int = 5, + decode_max_instances: int = 30, + decode_thr: float = 0.01, + ): + super().__init__() + + self.input_size = input_size + self.heatmap_size = heatmap_size + self.generate_keypoint_heatmaps = generate_keypoint_heatmaps + self.root_type = root_type + self.minimal_diagonal_length = minimal_diagonal_length + self.background_weight = background_weight + self.decode_nms_kernel = decode_nms_kernel + self.decode_max_instances = decode_max_instances + self.decode_thr = decode_thr + + self.scale_factor = (np.array(input_size) / + heatmap_size).astype(np.float32) + + if sigma is None: + sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 32 + if generate_keypoint_heatmaps: + # sigma for root heatmap and keypoint heatmaps + self.sigma = (sigma, sigma // 2) + else: + self.sigma = (sigma, ) + else: + if not isinstance(sigma, (tuple, list)): + sigma = (sigma, ) + if generate_keypoint_heatmaps: + assert len(sigma) == 2, 'sigma for keypoints must be given ' \ + 'if `generate_keypoint_heatmaps` ' \ + 'is True. e.g. sigma=(4, 2)' + self.sigma = sigma + + def _get_heatmap_weights(self, + heatmaps, + fg_weight: float = 1, + bg_weight: float = 0): + """Generate weight array for heatmaps. + + Args: + heatmaps (np.ndarray): Root and keypoint (optional) heatmaps + fg_weight (float): Weight for foreground pixels. Defaults to 1.0 + bg_weight (float): Weight for background pixels. Defaults to 0.0 + + Returns: + np.ndarray: Heatmap weight array in the same shape with heatmaps + """ + heatmap_weights = np.ones(heatmaps.shape) * bg_weight + heatmap_weights[heatmaps > 0] = fg_weight + return heatmap_weights + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + """Encode keypoints into root heatmaps and keypoint displacement + fields. Note that the original keypoint coordinates should be in the + input image space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + dict: + - heatmaps (np.ndarray): The generated heatmap in shape + (1, H, W) where [W, H] is the `heatmap_size`. If keypoint + heatmaps are generated together, the shape is (K+1, H, W) + - heatmap_weights (np.ndarray): The pixel-wise weight for heatmaps + which has same shape with `heatmaps` + - displacements (np.ndarray): The generated displacement fields in + shape (K*D, H, W). The vector on each pixels represents the + displacement of keypoints belong to the associated instance + from this pixel. + - displacement_weights (np.ndarray): The pixel-wise weight for + displacements which has same shape with `displacements` + """ + + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + # keypoint coordinates in heatmap + _keypoints = keypoints / self.scale_factor + + # compute the root and scale of each instance + roots, roots_visible = get_instance_root(_keypoints, keypoints_visible, + self.root_type) + diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible) + + # discard the small instances + roots_visible[diagonal_lengths < self.minimal_diagonal_length] = 0 + + # generate heatmaps + heatmaps, _ = generate_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=roots[:, None], + keypoints_visible=roots_visible[:, None], + sigma=self.sigma[0]) + heatmap_weights = self._get_heatmap_weights( + heatmaps, bg_weight=self.background_weight) + + if self.generate_keypoint_heatmaps: + keypoint_heatmaps, _ = generate_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=_keypoints, + keypoints_visible=keypoints_visible, + sigma=self.sigma[1]) + + keypoint_heatmaps_weights = self._get_heatmap_weights( + keypoint_heatmaps, bg_weight=self.background_weight) + + heatmaps = np.concatenate((keypoint_heatmaps, heatmaps), axis=0) + heatmap_weights = np.concatenate( + (keypoint_heatmaps_weights, heatmap_weights), axis=0) + + # generate displacements + displacements, displacement_weights = \ + generate_displacement_heatmap( + self.heatmap_size, + _keypoints, + keypoints_visible, + roots, + roots_visible, + diagonal_lengths, + self.sigma[0], + ) + + encoded = dict( + heatmaps=heatmaps, + heatmap_weights=heatmap_weights, + displacements=displacements, + displacement_weights=displacement_weights) + + return encoded + + def decode(self, heatmaps: Tensor, + displacements: Tensor) -> Tuple[np.ndarray, np.ndarray]: + """Decode the keypoint coordinates from heatmaps and displacements. The + decoded keypoint coordinates are in the input image space. + + Args: + heatmaps (Tensor): Encoded root and keypoints (optional) heatmaps + in shape (1, H, W) or (K+1, H, W) + displacements (Tensor): Encoded keypoints displacement fields + in shape (K*D, H, W) + + Returns: + tuple: + - keypoints (Tensor): Decoded keypoint coordinates in shape + (N, K, D) + - scores (tuple): + - root_scores (Tensor): The root scores in shape (N, ) + - keypoint_scores (Tensor): The keypoint scores in + shape (N, K). If keypoint heatmaps are not generated, + `keypoint_scores` will be `None` + """ + # heatmaps, displacements = encoded + _k, h, w = displacements.shape + k = _k // 2 + displacements = displacements.view(k, 2, h, w) + + # convert displacements to a dense keypoint prediction + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) + regular_grid = torch.stack([x, y], dim=0).to(displacements) + posemaps = (regular_grid[None] + displacements).flatten(2) + + # find local maximum on root heatmap + root_heatmap_peaks = batch_heatmap_nms(heatmaps[None, -1:], + self.decode_nms_kernel) + root_scores, pos_idx = root_heatmap_peaks.flatten().topk( + self.decode_max_instances) + mask = root_scores > self.decode_thr + root_scores, pos_idx = root_scores[mask], pos_idx[mask] + + keypoints = posemaps[:, :, pos_idx].permute(2, 0, 1).contiguous() + + if self.generate_keypoint_heatmaps and heatmaps.shape[0] == 1 + k: + # compute scores for each keypoint + keypoint_scores = self.get_keypoint_scores(heatmaps[:k], keypoints) + else: + keypoint_scores = None + + keypoints = torch.cat([ + kpt * self.scale_factor[i] + for i, kpt in enumerate(keypoints.split(1, -1)) + ], + dim=-1) + return keypoints, (root_scores, keypoint_scores) + + def get_keypoint_scores(self, heatmaps: Tensor, keypoints: Tensor): + """Calculate the keypoint scores with keypoints heatmaps and + coordinates. + + Args: + heatmaps (Tensor): Keypoint heatmaps in shape (K, H, W) + keypoints (Tensor): Keypoint coordinates in shape (N, K, D) + + Returns: + Tensor: Keypoint scores in [N, K] + """ + k, h, w = heatmaps.shape + keypoints = torch.stack(( + keypoints[..., 0] / (w - 1) * 2 - 1, + keypoints[..., 1] / (h - 1) * 2 - 1, + ), + dim=-1) + keypoints = keypoints.transpose(0, 1).unsqueeze(1).contiguous() + + keypoint_scores = torch.nn.functional.grid_sample( + heatmaps.unsqueeze(1), keypoints, + padding_mode='border').view(k, -1).transpose(0, 1).contiguous() + + return keypoint_scores diff --git a/mmpose/codecs/udp_heatmap.py b/mmpose/codecs/udp_heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..c38ea17be4327e6f7d433198c6a78d5ad2869342 --- /dev/null +++ b/mmpose/codecs/udp_heatmap.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import cv2 +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec +from .utils import (generate_offset_heatmap, generate_udp_gaussian_heatmaps, + get_heatmap_maximum, refine_keypoints_dark_udp) + + +@KEYPOINT_CODECS.register_module() +class UDPHeatmap(BaseKeypointCodec): + r"""Generate keypoint heatmaps by Unbiased Data Processing (UDP). + See the paper: `The Devil is in the Details: Delving into Unbiased Data + Processing for Human Pose Estimation`_ by Huang et al (2020) for details. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + - heatmap size: [W, H] + + Encoded: + + - heatmap (np.ndarray): The generated heatmap in shape (C_out, H, W) + where [W, H] is the `heatmap_size`, and the C_out is the output + channel number which depends on the `heatmap_type`. If + `heatmap_type=='gaussian'`, C_out equals to keypoint number K; + if `heatmap_type=='combined'`, C_out equals to K*3 + (x_offset, y_offset and class label) + - keypoint_weights (np.ndarray): The target weights in shape (K,) + + Args: + input_size (tuple): Image size in [w, h] + heatmap_size (tuple): Heatmap size in [W, H] + heatmap_type (str): The heatmap type to encode the keypoitns. Options + are: + + - ``'gaussian'``: Gaussian heatmap + - ``'combined'``: Combination of a binary label map and offset + maps for X and Y axes. + + sigma (float): The sigma value of the Gaussian heatmap when + ``heatmap_type=='gaussian'``. Defaults to 2.0 + radius_factor (float): The radius factor of the binary label + map when ``heatmap_type=='combined'``. The positive region is + defined as the neighbor of the keypoit with the radius + :math:`r=radius_factor*max(W, H)`. Defaults to 0.0546875 + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation in DarkPose. Defaults to 11 + + .. _`The Devil is in the Details: Delving into Unbiased Data Processing for + Human Pose Estimation`: https://arxiv.org/abs/1911.07524 + """ + + def __init__(self, + input_size: Tuple[int, int], + heatmap_size: Tuple[int, int], + heatmap_type: str = 'gaussian', + sigma: float = 2., + radius_factor: float = 0.0546875, + blur_kernel_size: int = 11) -> None: + super().__init__() + self.input_size = input_size + self.heatmap_size = heatmap_size + self.sigma = sigma + self.radius_factor = radius_factor + self.heatmap_type = heatmap_type + self.blur_kernel_size = blur_kernel_size + self.scale_factor = ((np.array(input_size) - 1) / + (np.array(heatmap_size) - 1)).astype(np.float32) + + if self.heatmap_type not in {'gaussian', 'combined'}: + raise ValueError( + f'{self.__class__.__name__} got invalid `heatmap_type` value' + f'{self.heatmap_type}. Should be one of ' + '{"gaussian", "combined"}') + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + """Encode keypoints into heatmaps. Note that the original keypoint + coordinates should be in the input image space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + dict: + - heatmap (np.ndarray): The generated heatmap in shape + (C_out, H, W) where [W, H] is the `heatmap_size`, and the + C_out is the output channel number which depends on the + `heatmap_type`. If `heatmap_type=='gaussian'`, C_out equals to + keypoint number K; if `heatmap_type=='combined'`, C_out + equals to K*3 (x_offset, y_offset and class label) + - keypoint_weights (np.ndarray): The target weights in shape + (K,) + """ + assert keypoints.shape[0] == 1, ( + f'{self.__class__.__name__} only support single-instance ' + 'keypoint encoding') + + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + if self.heatmap_type == 'gaussian': + heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=keypoints / self.scale_factor, + keypoints_visible=keypoints_visible, + sigma=self.sigma) + elif self.heatmap_type == 'combined': + heatmaps, keypoint_weights = generate_offset_heatmap( + heatmap_size=self.heatmap_size, + keypoints=keypoints / self.scale_factor, + keypoints_visible=keypoints_visible, + radius_factor=self.radius_factor) + else: + raise ValueError( + f'{self.__class__.__name__} got invalid `heatmap_type` value' + f'{self.heatmap_type}. Should be one of ' + '{"gaussian", "combined"}') + + encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights) + + return encoded + + def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from heatmaps. The decoded keypoint + coordinates are in the input image space. + + Args: + encoded (np.ndarray): Heatmaps in shape (K, H, W) + + Returns: + tuple: + - keypoints (np.ndarray): Decoded keypoint coordinates in shape + (N, K, D) + - scores (np.ndarray): The keypoint scores in shape (N, K). It + usually represents the confidence of the keypoint prediction + """ + heatmaps = encoded.copy() + + if self.heatmap_type == 'gaussian': + keypoints, scores = get_heatmap_maximum(heatmaps) + # unsqueeze the instance dimension for single-instance results + keypoints = keypoints[None] + scores = scores[None] + + keypoints = refine_keypoints_dark_udp( + keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size) + + elif self.heatmap_type == 'combined': + _K, H, W = heatmaps.shape + K = _K // 3 + + for cls_heatmap in heatmaps[::3]: + # Apply Gaussian blur on classification maps + ks = 2 * self.blur_kernel_size + 1 + cv2.GaussianBlur(cls_heatmap, (ks, ks), 0, cls_heatmap) + + # valid radius + radius = self.radius_factor * max(W, H) + + x_offset = heatmaps[1::3].flatten() * radius + y_offset = heatmaps[2::3].flatten() * radius + keypoints, scores = get_heatmap_maximum(heatmaps=heatmaps[::3]) + index = (keypoints[..., 0] + keypoints[..., 1] * W).flatten() + index += W * H * np.arange(0, K) + index = index.astype(int) + keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1) + # unsqueeze the instance dimension for single-instance results + keypoints = keypoints[None].astype(np.float32) + scores = scores[None] + + W, H = self.heatmap_size + keypoints = keypoints / [W - 1, H - 1] * self.input_size + + return keypoints, scores diff --git a/mmpose/codecs/utils/__init__.py b/mmpose/codecs/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa093f12bc6d69748509721b4b481811583339c --- /dev/null +++ b/mmpose/codecs/utils/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .gaussian_heatmap import (generate_gaussian_heatmaps, + generate_udp_gaussian_heatmaps, + generate_unbiased_gaussian_heatmaps) +from .instance_property import (get_diagonal_lengths, get_instance_bbox, + get_instance_root) +from .offset_heatmap import (generate_displacement_heatmap, + generate_offset_heatmap) +from .post_processing import (batch_heatmap_nms, gaussian_blur, + gaussian_blur1d, get_heatmap_maximum, + get_simcc_maximum, get_simcc_normalized) +from .refinement import (refine_keypoints, refine_keypoints_dark, + refine_keypoints_dark_udp, refine_simcc_dark) + +__all__ = [ + 'generate_gaussian_heatmaps', 'generate_udp_gaussian_heatmaps', + 'generate_unbiased_gaussian_heatmaps', 'gaussian_blur', + 'get_heatmap_maximum', 'get_simcc_maximum', 'generate_offset_heatmap', + 'batch_heatmap_nms', 'refine_keypoints', 'refine_keypoints_dark', + 'refine_keypoints_dark_udp', 'generate_displacement_heatmap', + 'refine_simcc_dark', 'gaussian_blur1d', 'get_diagonal_lengths', + 'get_instance_root', 'get_instance_bbox', 'get_simcc_normalized' +] diff --git a/mmpose/codecs/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf9afad4dbcabd24bbe2146f2dcfc482498dd53e Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/codecs/utils/__pycache__/gaussian_heatmap.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/gaussian_heatmap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..851accf31e4c11ea2a4f7fa703f01d8837b1880b Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/gaussian_heatmap.cpython-38.pyc differ diff --git a/mmpose/codecs/utils/__pycache__/instance_property.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/instance_property.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b80bd7e07b2d49c3760825295676502eb4ddf5d Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/instance_property.cpython-38.pyc differ diff --git a/mmpose/codecs/utils/__pycache__/offset_heatmap.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/offset_heatmap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efa9e9e323b38f7dad803d265cfac7171517c794 Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/offset_heatmap.cpython-38.pyc differ diff --git a/mmpose/codecs/utils/__pycache__/post_processing.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/post_processing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c358fd906e0a050246a7839ff1e2cbfe9362658 Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/post_processing.cpython-38.pyc differ diff --git a/mmpose/codecs/utils/__pycache__/refinement.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/refinement.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c931f875f52229b1749016d40991bcc1d8b0dcb Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/refinement.cpython-38.pyc differ diff --git a/mmpose/codecs/utils/gaussian_heatmap.py b/mmpose/codecs/utils/gaussian_heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..91e08c2cdd41ddcb751e3afd13555731731a0602 --- /dev/null +++ b/mmpose/codecs/utils/gaussian_heatmap.py @@ -0,0 +1,227 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product +from typing import Tuple, Union + +import numpy as np + + +def generate_gaussian_heatmaps( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + sigma: Union[float, Tuple[float], np.ndarray], +) -> Tuple[np.ndarray, np.ndarray]: + """Generate gaussian heatmaps of keypoints. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + sigma (float or List[float]): A list of sigma values of the Gaussian + heatmap for each instance. If sigma is given as a single float + value, it will be expanded into a tuple + + Returns: + tuple: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + """ + + N, K, _ = keypoints.shape + W, H = heatmap_size + + heatmaps = np.zeros((K, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + if isinstance(sigma, (int, float)): + sigma = (sigma, ) * N + + for n in range(N): + # 3-sigma rule + radius = sigma[n] * 3 + + # xy grid + gaussian_size = 2 * radius + 1 + x = np.arange(0, gaussian_size, 1, dtype=np.float32) + y = x[:, None] + x0 = y0 = gaussian_size // 2 + + for k in range(K): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + # get gaussian center coordinates + mu = (keypoints[n, k] + 0.5).astype(np.int64) + + # check that the gaussian has in-bounds part + left, top = (mu - radius).astype(np.int64) + right, bottom = (mu + radius + 1).astype(np.int64) + + if left >= W or top >= H or right < 0 or bottom < 0: + keypoint_weights[n, k] = 0 + continue + + # The gaussian is not normalized, + # we want the center value to equal 1 + gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma[n]**2)) + + # valid range in gaussian + g_x1 = max(0, -left) + g_x2 = min(W, right) - left + g_y1 = max(0, -top) + g_y2 = min(H, bottom) - top + + # valid range in heatmap + h_x1 = max(0, left) + h_x2 = min(W, right) + h_y1 = max(0, top) + h_y2 = min(H, bottom) + + heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2] + gaussian_regsion = gaussian[g_y1:g_y2, g_x1:g_x2] + + _ = np.maximum( + heatmap_region, gaussian_regsion, out=heatmap_region) + + return heatmaps, keypoint_weights + + +def generate_unbiased_gaussian_heatmaps( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + sigma: float, +) -> Tuple[np.ndarray, np.ndarray]: + """Generate gaussian heatmaps of keypoints using `Dark Pose`_. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + tuple: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + + .. _`Dark Pose`: https://arxiv.org/abs/1910.06278 + """ + + N, K, _ = keypoints.shape + W, H = heatmap_size + + heatmaps = np.zeros((K, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + # 3-sigma rule + radius = sigma * 3 + + # xy grid + x = np.arange(0, W, 1, dtype=np.float32) + y = np.arange(0, H, 1, dtype=np.float32)[:, None] + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + mu = keypoints[n, k] + # check that the gaussian has in-bounds part + left, top = mu - radius + right, bottom = mu + radius + 1 + + if left >= W or top >= H or right < 0 or bottom < 0: + keypoint_weights[n, k] = 0 + continue + + gaussian = np.exp(-((x - mu[0])**2 + (y - mu[1])**2) / (2 * sigma**2)) + + _ = np.maximum(gaussian, heatmaps[k], out=heatmaps[k]) + + return heatmaps, keypoint_weights + + +def generate_udp_gaussian_heatmaps( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + sigma: float, +) -> Tuple[np.ndarray, np.ndarray]: + """Generate gaussian heatmaps of keypoints using `UDP`_. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + sigma (float): The sigma value of the Gaussian heatmap + + Returns: + tuple: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + + .. _`UDP`: https://arxiv.org/abs/1911.07524 + """ + + N, K, _ = keypoints.shape + W, H = heatmap_size + + heatmaps = np.zeros((K, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + # 3-sigma rule + radius = sigma * 3 + + # xy grid + gaussian_size = 2 * radius + 1 + x = np.arange(0, gaussian_size, 1, dtype=np.float32) + y = x[:, None] + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + mu = (keypoints[n, k] + 0.5).astype(np.int64) + # check that the gaussian has in-bounds part + left, top = (mu - radius).astype(np.int64) + right, bottom = (mu + radius + 1).astype(np.int64) + + if left >= W or top >= H or right < 0 or bottom < 0: + keypoint_weights[n, k] = 0 + continue + + mu_ac = keypoints[n, k] + x0 = y0 = gaussian_size // 2 + x0 += mu_ac[0] - mu[0] + y0 += mu_ac[1] - mu[1] + gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2)) + + # valid range in gaussian + g_x1 = max(0, -left) + g_x2 = min(W, right) - left + g_y1 = max(0, -top) + g_y2 = min(H, bottom) - top + + # valid range in heatmap + h_x1 = max(0, left) + h_x2 = min(W, right) + h_y1 = max(0, top) + h_y2 = min(H, bottom) + + heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2] + gaussian_regsion = gaussian[g_y1:g_y2, g_x1:g_x2] + + _ = np.maximum(heatmap_region, gaussian_regsion, out=heatmap_region) + + return heatmaps, keypoint_weights diff --git a/mmpose/codecs/utils/instance_property.py b/mmpose/codecs/utils/instance_property.py new file mode 100644 index 0000000000000000000000000000000000000000..15ae30aef021939e2f0dbf276ce8b1c3cceaa40e --- /dev/null +++ b/mmpose/codecs/utils/instance_property.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import numpy as np + + +def get_instance_root(keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + root_type: str = 'kpt_center') -> np.ndarray: + """Calculate the coordinates and visibility of instance roots. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + root_type (str): Calculation of instance roots which should + be one of the following options: + + - ``'kpt_center'``: The roots' coordinates are the mean + coordinates of visible keypoints + - ``'bbox_center'``: The roots' are the center of bounding + boxes outlined by visible keypoints + + Defaults to ``'kpt_center'`` + + Returns: + tuple + - roots_coordinate(np.ndarray): Coordinates of instance roots in + shape [N, D] + - roots_visible(np.ndarray): Visibility of instance roots in + shape [N] + """ + + roots_coordinate = np.zeros((keypoints.shape[0], 2), dtype=np.float32) + roots_visible = np.ones((keypoints.shape[0]), dtype=np.float32) * 2 + + for i in range(keypoints.shape[0]): + + # collect visible keypoints + if keypoints_visible is not None: + visible_keypoints = keypoints[i][keypoints_visible[i] > 0] + else: + visible_keypoints = keypoints[i] + if visible_keypoints.size == 0: + roots_visible[i] = 0 + continue + + # compute the instance root with visible keypoints + if root_type == 'kpt_center': + roots_coordinate[i] = visible_keypoints.mean(axis=0) + roots_visible[i] = 1 + elif root_type == 'bbox_center': + roots_coordinate[i] = (visible_keypoints.max(axis=0) + + visible_keypoints.min(axis=0)) / 2.0 + roots_visible[i] = 1 + else: + raise ValueError( + f'the value of `root_type` must be \'kpt_center\' or ' + f'\'bbox_center\', but got \'{root_type}\'') + + return roots_coordinate, roots_visible + + +def get_instance_bbox(keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> np.ndarray: + """Calculate the pseudo instance bounding box from visible keypoints. The + bounding boxes are in the xyxy format. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + np.ndarray: bounding boxes in [N, 4] + """ + bbox = np.zeros((keypoints.shape[0], 4), dtype=np.float32) + for i in range(keypoints.shape[0]): + if keypoints_visible is not None: + visible_keypoints = keypoints[i][keypoints_visible[i] > 0] + else: + visible_keypoints = keypoints[i] + if visible_keypoints.size == 0: + continue + + bbox[i, :2] = visible_keypoints.min(axis=0) + bbox[i, 2:] = visible_keypoints.max(axis=0) + return bbox + + +def get_diagonal_lengths(keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> np.ndarray: + """Calculate the diagonal length of instance bounding box from visible + keypoints. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + np.ndarray: bounding box diagonal length in [N] + """ + pseudo_bbox = get_instance_bbox(keypoints, keypoints_visible) + pseudo_bbox = pseudo_bbox.reshape(-1, 2, 2) + h_w_diff = pseudo_bbox[:, 1] - pseudo_bbox[:, 0] + diagonal_length = np.sqrt(np.power(h_w_diff, 2).sum(axis=1)) + + return diagonal_length diff --git a/mmpose/codecs/utils/offset_heatmap.py b/mmpose/codecs/utils/offset_heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..c3c1c32ed391982fa0f8cd31b6240363b4fe1c52 --- /dev/null +++ b/mmpose/codecs/utils/offset_heatmap.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product +from typing import Tuple + +import numpy as np + + +def generate_offset_heatmap( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + radius_factor: float, +) -> Tuple[np.ndarray, np.ndarray]: + """Generate offset heatmaps of keypoints, where each keypoint is + represented by 3 maps: one pixel-level class label map (1 for keypoint and + 0 for non-keypoint) and 2 pixel-level offset maps for x and y directions + respectively. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + radius_factor (float): The radius factor of the binary label + map. The positive region is defined as the neighbor of the + keypoint with the radius :math:`r=radius_factor*max(W, H)` + + Returns: + tuple: + - heatmap (np.ndarray): The generated heatmap in shape + (K*3, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (K,) + """ + + N, K, _ = keypoints.shape + W, H = heatmap_size + + heatmaps = np.zeros((K, 3, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + # xy grid + x = np.arange(0, W, 1) + y = np.arange(0, H, 1)[:, None] + + # positive area radius in the classification map + radius = radius_factor * max(W, H) + + for n, k in product(range(N), range(K)): + if keypoints_visible[n, k] < 0.5: + continue + + mu = keypoints[n, k] + + x_offset = (mu[0] - x) / radius + y_offset = (mu[1] - y) / radius + + heatmaps[k, 0] = np.where(x_offset**2 + y_offset**2 <= 1, 1., 0.) + heatmaps[k, 1] = x_offset + heatmaps[k, 2] = y_offset + + heatmaps = heatmaps.reshape(K * 3, H, W) + + return heatmaps, keypoint_weights + + +def generate_displacement_heatmap( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + roots: np.ndarray, + roots_visible: np.ndarray, + diagonal_lengths: np.ndarray, + radius: float, +): + """Generate displacement heatmaps of keypoints, where each keypoint is + represented by 3 maps: one pixel-level class label map (1 for keypoint and + 0 for non-keypoint) and 2 pixel-level offset maps for x and y directions + respectively. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + roots (np.ndarray): Coordinates of instance centers in shape (N, D). + The displacement fields of each instance will locate around its + center. + roots_visible (np.ndarray): Roots visibilities in shape (N,) + diagonal_lengths (np.ndarray): Diaginal length of the bounding boxes + of each instance in shape (N,) + radius (float): The radius factor of the binary label + map. The positive region is defined as the neighbor of the + keypoint with the radius :math:`r=radius_factor*max(W, H)` + + Returns: + tuple: + - displacements (np.ndarray): The generated displacement map in + shape (K*2, H, W) where [W, H] is the `heatmap_size` + - displacement_weights (np.ndarray): The target weights in shape + (K*2, H, W) + """ + N, K, _ = keypoints.shape + W, H = heatmap_size + + displacements = np.zeros((K * 2, H, W), dtype=np.float32) + displacement_weights = np.zeros((K * 2, H, W), dtype=np.float32) + instance_size_map = np.zeros((H, W), dtype=np.float32) + + for n in range(N): + if (roots_visible[n] < 1 or (roots[n, 0] < 0 or roots[n, 1] < 0) + or (roots[n, 0] >= W or roots[n, 1] >= H)): + continue + + diagonal_length = diagonal_lengths[n] + + for k in range(K): + if keypoints_visible[n, k] < 1 or keypoints[n, k, 0] < 0 \ + or keypoints[n, k, 1] < 0 or keypoints[n, k, 0] >= W \ + or keypoints[n, k, 1] >= H: + continue + + start_x = max(int(roots[n, 0] - radius), 0) + start_y = max(int(roots[n, 1] - radius), 0) + end_x = min(int(roots[n, 0] + radius), W) + end_y = min(int(roots[n, 1] + radius), H) + + for x in range(start_x, end_x): + for y in range(start_y, end_y): + if displacements[2 * k, y, + x] != 0 or displacements[2 * k + 1, y, + x] != 0: + if diagonal_length > instance_size_map[y, x]: + # keep the gt displacement of smaller instance + continue + + displacement_weights[2 * k:2 * k + 2, y, + x] = 1 / diagonal_length + displacements[2 * k:2 * k + 2, y, + x] = keypoints[n, k] - [x, y] + instance_size_map[y, x] = diagonal_length + + return displacements, displacement_weights diff --git a/mmpose/codecs/utils/post_processing.py b/mmpose/codecs/utils/post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..75356388dc408d8dda0a72324aa16c3b4f3b6068 --- /dev/null +++ b/mmpose/codecs/utils/post_processing.py @@ -0,0 +1,227 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product +from typing import Tuple + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + + +def get_simcc_normalized(batch_pred_simcc, sigma=None): + """Normalize the predicted SimCC. + + Args: + batch_pred_simcc (torch.Tensor): The predicted SimCC. + sigma (float): The sigma of the Gaussian distribution. + + Returns: + torch.Tensor: The normalized SimCC. + """ + B, K, _ = batch_pred_simcc.shape + + # Scale and clamp the tensor + if sigma is not None: + batch_pred_simcc = batch_pred_simcc / (sigma * np.sqrt(np.pi * 2)) + batch_pred_simcc = batch_pred_simcc.clamp(min=0) + + # Compute the binary mask + mask = (batch_pred_simcc.amax(dim=-1) > 1).reshape(B, K, 1) + + # Normalize the tensor using the maximum value + norm = (batch_pred_simcc / batch_pred_simcc.amax(dim=-1).reshape(B, K, 1)) + + # Apply normalization + batch_pred_simcc = torch.where(mask, norm, batch_pred_simcc) + + return batch_pred_simcc + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + + assert isinstance(simcc_x, np.ndarray), ('simcc_x should be numpy.ndarray') + assert isinstance(simcc_y, np.ndarray), ('simcc_y should be numpy.ndarray') + assert simcc_x.ndim == 2 or simcc_x.ndim == 3, ( + f'Invalid shape {simcc_x.shape}') + assert simcc_y.ndim == 2 or simcc_y.ndim == 3, ( + f'Invalid shape {simcc_y.shape}') + assert simcc_x.ndim == simcc_y.ndim, ( + f'{simcc_x.shape} != {simcc_y.shape}') + + if simcc_x.ndim == 3: + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + else: + N = None + + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + if N: + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def get_heatmap_maximum(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from heatmaps. + + Note: + batch_size: B + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + heatmaps (np.ndarray): Heatmaps in shape (K, H, W) or (B, K, H, W) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (B, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (B, K) + """ + assert isinstance(heatmaps, + np.ndarray), ('heatmaps should be numpy.ndarray') + assert heatmaps.ndim == 3 or heatmaps.ndim == 4, ( + f'Invalid shape {heatmaps.shape}') + + if heatmaps.ndim == 3: + K, H, W = heatmaps.shape + B = None + heatmaps_flatten = heatmaps.reshape(K, -1) + else: + B, K, H, W = heatmaps.shape + heatmaps_flatten = heatmaps.reshape(B * K, -1) + + y_locs, x_locs = np.unravel_index( + np.argmax(heatmaps_flatten, axis=1), shape=(H, W)) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + vals = np.amax(heatmaps_flatten, axis=1) + locs[vals <= 0.] = -1 + + if B: + locs = locs.reshape(B, K, 2) + vals = vals.reshape(B, K) + + return locs, vals + + +def gaussian_blur(heatmaps: np.ndarray, kernel: int = 11) -> np.ndarray: + """Modulate heatmap distribution with Gaussian. + + Note: + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[K, H, W]): model predicted heatmaps. + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + + Returns: + np.ndarray ([K, H, W]): Modulated heatmap distribution. + """ + assert kernel % 2 == 1 + + border = (kernel - 1) // 2 + K, H, W = heatmaps.shape + + for k in range(K): + origin_max = np.max(heatmaps[k]) + dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = heatmaps[k].copy() + dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) + heatmaps[k] = dr[border:-border, border:-border].copy() + heatmaps[k] *= origin_max / np.max(heatmaps[k]) + return heatmaps + + +def gaussian_blur1d(simcc: np.ndarray, kernel: int = 11) -> np.ndarray: + """Modulate simcc distribution with Gaussian. + + Note: + - num_keypoints: K + - simcc length: Wx + + Args: + simcc (np.ndarray[K, Wx]): model predicted simcc. + kernel (int): Gaussian kernel size (K) for modulation, which should + match the simcc gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + + Returns: + np.ndarray ([K, Wx]): Modulated simcc distribution. + """ + assert kernel % 2 == 1 + + border = (kernel - 1) // 2 + N, K, Wx = simcc.shape + + for n, k in product(range(N), range(K)): + origin_max = np.max(simcc[n, k]) + dr = np.zeros((1, Wx + 2 * border), dtype=np.float32) + dr[0, border:-border] = simcc[n, k].copy() + dr = cv2.GaussianBlur(dr, (kernel, 1), 0) + simcc[n, k] = dr[0, border:-border].copy() + simcc[n, k] *= origin_max / np.max(simcc[n, k]) + return simcc + + +def batch_heatmap_nms(batch_heatmaps: Tensor, kernel_size: int = 5): + """Apply NMS on a batch of heatmaps. + + Args: + batch_heatmaps (Tensor): batch heatmaps in shape (B, K, H, W) + kernel_size (int): The kernel size of the NMS which should be + a odd integer. Defaults to 5 + + Returns: + Tensor: The batch heatmaps after NMS. + """ + + assert isinstance(kernel_size, int) and kernel_size % 2 == 1, \ + f'The kernel_size should be an odd integer, got {kernel_size}' + + padding = (kernel_size - 1) // 2 + + maximum = F.max_pool2d( + batch_heatmaps, kernel_size, stride=1, padding=padding) + maximum_indicator = torch.eq(batch_heatmaps, maximum) + batch_heatmaps = batch_heatmaps * maximum_indicator.float() + + return batch_heatmaps diff --git a/mmpose/codecs/utils/refinement.py b/mmpose/codecs/utils/refinement.py new file mode 100644 index 0000000000000000000000000000000000000000..3495f37d0acf62f860b76609e4ddfd35737c22d8 --- /dev/null +++ b/mmpose/codecs/utils/refinement.py @@ -0,0 +1,215 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product + +import numpy as np + +from .post_processing import gaussian_blur, gaussian_blur1d + + +def refine_keypoints(keypoints: np.ndarray, + heatmaps: np.ndarray) -> np.ndarray: + """Refine keypoint predictions by moving from the maximum towards the + second maximum by 0.25 pixel. The operation is in-place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - heatmap size: [W, H] + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + heatmaps (np.ndarray): The heatmaps in shape (K, H, W) + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + """ + N, K = keypoints.shape[:2] + H, W = heatmaps.shape[1:] + + for n, k in product(range(N), range(K)): + x, y = keypoints[n, k, :2].astype(int) + + if 1 < x < W - 1 and 0 < y < H: + dx = heatmaps[k, y, x + 1] - heatmaps[k, y, x - 1] + else: + dx = 0. + + if 1 < y < H - 1 and 0 < x < W: + dy = heatmaps[k, y + 1, x] - heatmaps[k, y - 1, x] + else: + dy = 0. + + keypoints[n, k] += np.sign([dx, dy], dtype=np.float32) * 0.25 + + return keypoints + + +def refine_keypoints_dark(keypoints: np.ndarray, heatmaps: np.ndarray, + blur_kernel_size: int) -> np.ndarray: + """Refine keypoint predictions using distribution aware coordinate + decoding. See `Dark Pose`_ for details. The operation is in-place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - heatmap size: [W, H] + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + heatmaps (np.ndarray): The heatmaps in shape (K, H, W) + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + + .. _`Dark Pose`: https://arxiv.org/abs/1910.06278 + """ + N, K = keypoints.shape[:2] + H, W = heatmaps.shape[1:] + + # modulate heatmaps + heatmaps = gaussian_blur(heatmaps, blur_kernel_size) + np.maximum(heatmaps, 1e-10, heatmaps) + np.log(heatmaps, heatmaps) + + for n, k in product(range(N), range(K)): + x, y = keypoints[n, k, :2].astype(int) + if 1 < x < W - 2 and 1 < y < H - 2: + dx = 0.5 * (heatmaps[k, y, x + 1] - heatmaps[k, y, x - 1]) + dy = 0.5 * (heatmaps[k, y + 1, x] - heatmaps[k, y - 1, x]) + + dxx = 0.25 * ( + heatmaps[k, y, x + 2] - 2 * heatmaps[k, y, x] + + heatmaps[k, y, x - 2]) + dxy = 0.25 * ( + heatmaps[k, y + 1, x + 1] - heatmaps[k, y - 1, x + 1] - + heatmaps[k, y + 1, x - 1] + heatmaps[k, y - 1, x - 1]) + dyy = 0.25 * ( + heatmaps[k, y + 2, x] - 2 * heatmaps[k, y, x] + + heatmaps[k, y - 2, x]) + derivative = np.array([[dx], [dy]]) + hessian = np.array([[dxx, dxy], [dxy, dyy]]) + if dxx * dyy - dxy**2 != 0: + hessianinv = np.linalg.inv(hessian) + offset = -hessianinv @ derivative + offset = np.squeeze(np.array(offset.T), axis=0) + keypoints[n, k, :2] += offset + return keypoints + + +def refine_keypoints_dark_udp(keypoints: np.ndarray, heatmaps: np.ndarray, + blur_kernel_size: int) -> np.ndarray: + """Refine keypoint predictions using distribution aware coordinate decoding + for UDP. See `UDP`_ for details. The operation is in-place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - heatmap size: [W, H] + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + heatmaps (np.ndarray): The heatmaps in shape (K, H, W) + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + + .. _`UDP`: https://arxiv.org/abs/1911.07524 + """ + N, K = keypoints.shape[:2] + H, W = heatmaps.shape[1:] + + # modulate heatmaps + heatmaps = gaussian_blur(heatmaps, blur_kernel_size) + np.clip(heatmaps, 1e-3, 50., heatmaps) + np.log(heatmaps, heatmaps) + + heatmaps_pad = np.pad( + heatmaps, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten() + + for n in range(N): + index = keypoints[n, :, 0] + 1 + (keypoints[n, :, 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, K) + index = index.astype(int).reshape(-1, 1) + i_ = heatmaps_pad[index] + ix1 = heatmaps_pad[index + 1] + iy1 = heatmaps_pad[index + W + 2] + ix1y1 = heatmaps_pad[index + W + 3] + ix1_y1_ = heatmaps_pad[index - W - 3] + ix1_ = heatmaps_pad[index - 1] + iy1_ = heatmaps_pad[index - 2 - W] + + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1) + derivative = derivative.reshape(K, 2, 1) + + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1) + hessian = hessian.reshape(K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + keypoints[n] -= np.einsum('imn,ink->imk', hessian, + derivative).squeeze() + + return keypoints + + +def refine_simcc_dark(keypoints: np.ndarray, simcc: np.ndarray, + blur_kernel_size: int) -> np.ndarray: + """SimCC version. Refine keypoint predictions using distribution aware + coordinate decoding for UDP. See `UDP`_ for details. The operation is in- + place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + simcc (np.ndarray): The heatmaps in shape (N, K, Wx) + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + + .. _`UDP`: https://arxiv.org/abs/1911.07524 + """ + N = simcc.shape[0] + + # modulate simcc + simcc = gaussian_blur1d(simcc, blur_kernel_size) + np.clip(simcc, 1e-3, 50., simcc) + np.log(simcc, simcc) + + simcc = np.pad(simcc, ((0, 0), (0, 0), (2, 2)), 'edge') + + for n in range(N): + px = (keypoints[n] + 2.5).astype(np.int64).reshape(-1, 1) # K, 1 + + dx0 = np.take_along_axis(simcc[n], px, axis=1) # K, 1 + dx1 = np.take_along_axis(simcc[n], px + 1, axis=1) + dx_1 = np.take_along_axis(simcc[n], px - 1, axis=1) + dx2 = np.take_along_axis(simcc[n], px + 2, axis=1) + dx_2 = np.take_along_axis(simcc[n], px - 2, axis=1) + + dx = 0.5 * (dx1 - dx_1) + dxx = 1e-9 + 0.25 * (dx2 - 2 * dx0 + dx_2) + + offset = dx / dxx + keypoints[n] -= offset.reshape(-1) + + return keypoints diff --git a/mmpose/datasets/__init__.py b/mmpose/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b90a12db4937ffca9ff103b1e5a0c7604de52e0b --- /dev/null +++ b/mmpose/datasets/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import build_dataset +from .dataset_wrappers import CombinedDataset +from .datasets import * # noqa +from .samplers import MultiSourceSampler +from .transforms import * # noqa + +__all__ = ['build_dataset', 'CombinedDataset', 'MultiSourceSampler'] diff --git a/mmpose/datasets/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14ad2c4a8c94b4ef8fbd4f2f555faf4e87ebb6ce Binary files /dev/null and b/mmpose/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/__pycache__/builder.cpython-38.pyc b/mmpose/datasets/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd5c0867baf23dc53275361f3623148f3c0c1656 Binary files /dev/null and b/mmpose/datasets/__pycache__/builder.cpython-38.pyc differ diff --git a/mmpose/datasets/__pycache__/dataset_wrappers.cpython-38.pyc b/mmpose/datasets/__pycache__/dataset_wrappers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd82e726c759c40f56013c0a96a63f178bb3e06a Binary files /dev/null and b/mmpose/datasets/__pycache__/dataset_wrappers.cpython-38.pyc differ diff --git a/mmpose/datasets/__pycache__/samplers.cpython-38.pyc b/mmpose/datasets/__pycache__/samplers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b758d8b51ab4fb08b319f36d3542c84f27711646 Binary files /dev/null and b/mmpose/datasets/__pycache__/samplers.cpython-38.pyc differ diff --git a/mmpose/datasets/builder.py b/mmpose/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..2e5a236ff49b70b86149d318cbccdfd5af5a6450 --- /dev/null +++ b/mmpose/datasets/builder.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import platform +import random + +import numpy as np +import torch +from mmengine import build_from_cfg, is_seq_of +from mmengine.dataset import ConcatDataset, RepeatDataset + +from mmpose.registry import DATASETS + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + base_soft_limit = rlimit[0] + hard_limit = rlimit[1] + soft_limit = min(max(4096, base_soft_limit), hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + + +def _concat_dataset(cfg, default_args=None): + types = cfg['type'] + ann_files = cfg['ann_file'] + img_prefixes = cfg.get('img_prefix', None) + dataset_infos = cfg.get('dataset_info', None) + + num_joints = cfg['data_cfg'].get('num_joints', None) + dataset_channel = cfg['data_cfg'].get('dataset_channel', None) + + datasets = [] + num_dset = len(ann_files) + for i in range(num_dset): + cfg_copy = copy.deepcopy(cfg) + cfg_copy['ann_file'] = ann_files[i] + + if isinstance(types, (list, tuple)): + cfg_copy['type'] = types[i] + if isinstance(img_prefixes, (list, tuple)): + cfg_copy['img_prefix'] = img_prefixes[i] + if isinstance(dataset_infos, (list, tuple)): + cfg_copy['dataset_info'] = dataset_infos[i] + + if isinstance(num_joints, (list, tuple)): + cfg_copy['data_cfg']['num_joints'] = num_joints[i] + + if is_seq_of(dataset_channel, list): + cfg_copy['data_cfg']['dataset_channel'] = dataset_channel[i] + + datasets.append(build_dataset(cfg_copy, default_args)) + + return ConcatDataset(datasets) + + +def build_dataset(cfg, default_args=None): + """Build a dataset from config dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + default_args (dict, optional): Default initialization arguments. + Default: None. + + Returns: + Dataset: The constructed dataset. + """ + + if isinstance(cfg, (list, tuple)): + dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) + elif cfg['type'] == 'ConcatDataset': + dataset = ConcatDataset( + [build_dataset(c, default_args) for c in cfg['datasets']]) + elif cfg['type'] == 'RepeatDataset': + dataset = RepeatDataset( + build_dataset(cfg['dataset'], default_args), cfg['times']) + elif isinstance(cfg.get('ann_file'), (list, tuple)): + dataset = _concat_dataset(cfg, default_args) + else: + dataset = build_from_cfg(cfg, DATASETS, default_args) + return dataset + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Init the random seed for various workers.""" + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..28eeac9945199c08d0de89b5348511c3caae790d --- /dev/null +++ b/mmpose/datasets/dataset_wrappers.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from copy import deepcopy +from typing import Any, Callable, List, Tuple, Union + +from mmengine.dataset import BaseDataset +from mmengine.registry import build_from_cfg + +from mmpose.registry import DATASETS +from .datasets.utils import parse_pose_metainfo + + +@DATASETS.register_module() +class CombinedDataset(BaseDataset): + """A wrapper of combined dataset. + + Args: + metainfo (dict): The meta information of combined dataset. + datasets (list): The configs of datasets to be combined. + pipeline (list, optional): Processing pipeline. Defaults to []. + """ + + def __init__(self, + metainfo: dict, + datasets: list, + pipeline: List[Union[dict, Callable]] = [], + **kwargs): + + self.datasets = [] + + for cfg in datasets: + dataset = build_from_cfg(cfg, DATASETS) + self.datasets.append(dataset) + + self._lens = [len(dataset) for dataset in self.datasets] + self._len = sum(self._lens) + + super(CombinedDataset, self).__init__(pipeline=pipeline, **kwargs) + self._metainfo = parse_pose_metainfo(metainfo) + + @property + def metainfo(self): + return deepcopy(self._metainfo) + + def __len__(self): + return self._len + + def _get_subset_index(self, index: int) -> Tuple[int, int]: + """Given a data sample's global index, return the index of the sub- + dataset the data sample belongs to, and the local index within that + sub-dataset. + + Args: + index (int): The global data sample index + + Returns: + tuple[int, int]: + - subset_index (int): The index of the sub-dataset + - local_index (int): The index of the data sample within + the sub-dataset + """ + if index >= len(self) or index < -len(self): + raise ValueError( + f'index({index}) is out of bounds for dataset with ' + f'length({len(self)}).') + + if index < 0: + index = index + len(self) + + subset_index = 0 + while index >= self._lens[subset_index]: + index -= self._lens[subset_index] + subset_index += 1 + return subset_index, index + + def prepare_data(self, idx: int) -> Any: + """Get data processed by ``self.pipeline``.The source dataset is + depending on the index. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + + data_info = self.get_data_info(idx) + + return self.pipeline(data_info) + + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``CombinedDataset``. + Returns: + dict: The idx-th annotation of the datasets. + """ + subset_idx, sample_idx = self._get_subset_index(idx) + # Get data sample processed by ``subset.pipeline`` + data_info = self.datasets[subset_idx][sample_idx] + + # Add metainfo items that are required in the pipeline and the model + metainfo_keys = [ + 'upper_body_ids', 'lower_body_ids', 'flip_pairs', + 'dataset_keypoint_weights', 'flip_indices' + ] + + for key in metainfo_keys: + data_info[key] = deepcopy(self._metainfo[key]) + + return data_info + + def full_init(self): + """Fully initialize all sub datasets.""" + + if self._fully_initialized: + return + + for dataset in self.datasets: + dataset.full_init() + self._fully_initialized = True diff --git a/mmpose/datasets/datasets/__init__.py b/mmpose/datasets/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03a0f493ca197fe2d6004ed0eb91af1d3e693524 --- /dev/null +++ b/mmpose/datasets/datasets/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .animal import * # noqa: F401, F403 +from .base import * # noqa: F401, F403 +from .body import * # noqa: F401, F403 +from .face import * # noqa: F401, F403 +from .fashion import * # noqa: F401, F403 +from .hand import * # noqa: F401, F403 +from .wholebody import * # noqa: F401, F403 diff --git a/mmpose/datasets/datasets/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b506b9edca6cc639a78b51e976890787716250e0 Binary files /dev/null and b/mmpose/datasets/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/__pycache__/utils.cpython-38.pyc b/mmpose/datasets/datasets/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e472c408ae95f858609ecc38664dacdaaf88b9e Binary files /dev/null and b/mmpose/datasets/datasets/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__init__.py b/mmpose/datasets/datasets/animal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe9b5938c6c16241727922ef1e4e8c91eb058aa --- /dev/null +++ b/mmpose/datasets/datasets/animal/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .animalpose_dataset import AnimalPoseDataset +from .ap10k_dataset import AP10KDataset +from .atrw_dataset import ATRWDataset +from .fly_dataset import FlyDataset +from .horse10_dataset import Horse10Dataset +from .locust_dataset import LocustDataset +from .macaque_dataset import MacaqueDataset +from .zebra_dataset import ZebraDataset + +__all__ = [ + 'AnimalPoseDataset', 'AP10KDataset', 'Horse10Dataset', 'MacaqueDataset', + 'FlyDataset', 'LocustDataset', 'ZebraDataset', 'ATRWDataset' +] diff --git a/mmpose/datasets/datasets/animal/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df0a3bdcdade8eea3dab047352a47b0bc82b3acd Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__pycache__/animalpose_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/animalpose_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda9277b9a423ee27ce709ea4369c3170560e978 Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/animalpose_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__pycache__/ap10k_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/ap10k_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..879e1a96fbafdb61eb0fb816222efbf05badecc0 Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/ap10k_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__pycache__/atrw_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/atrw_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92149fa8b6a61770d8bca4486ae12467d4c967d5 Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/atrw_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__pycache__/fly_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/fly_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..579409003837ce6c27375d159242bfce39ef1a78 Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/fly_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__pycache__/horse10_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/horse10_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79780b41be5a7125fd72e0e629437e2805b39451 Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/horse10_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__pycache__/locust_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/locust_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..256924df7a26dc2ba4dabb9336b322bafbc3fee1 Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/locust_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__pycache__/macaque_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/macaque_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4efa67a70c8ee2f9b7a1b87f35c5f9642d650f59 Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/macaque_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/__pycache__/zebra_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/zebra_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df79f7f8042dd4aefcc3e7c7aace026769cf2396 Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/zebra_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/animal/animalpose_dataset.py b/mmpose/datasets/datasets/animal/animalpose_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0279cf9de0907626f2a6686170dc5e99aafa2d9d --- /dev/null +++ b/mmpose/datasets/datasets/animal/animalpose_dataset.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class AnimalPoseDataset(BaseCocoStyleDataset): + """Animal-Pose dataset for animal pose estimation. + + "Cross-domain Adaptation For Animal Pose Estimation" ICCV'2019 + More details can be found in the `paper + `__ . + + Animal-Pose keypoints:: + + 0: 'L_Eye', + 1: 'R_Eye', + 2: 'L_EarBase', + 3: 'R_EarBase', + 4: 'Nose', + 5: 'Throat', + 6: 'TailBase', + 7: 'Withers', + 8: 'L_F_Elbow', + 9: 'R_F_Elbow', + 10: 'L_B_Elbow', + 11: 'R_B_Elbow', + 12: 'L_F_Knee', + 13: 'R_F_Knee', + 14: 'L_B_Knee', + 15: 'R_B_Knee', + 16: 'L_F_Paw', + 17: 'R_F_Paw', + 18: 'L_B_Paw', + 19: 'R_B_Paw' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/animalpose.py') diff --git a/mmpose/datasets/datasets/animal/ap10k_dataset.py b/mmpose/datasets/datasets/animal/ap10k_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..de1efbc67f7be55c57532684174442a3f865d5fd --- /dev/null +++ b/mmpose/datasets/datasets/animal/ap10k_dataset.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class AP10KDataset(BaseCocoStyleDataset): + """AP-10K dataset for animal pose estimation. + + "AP-10K: A Benchmark for Animal Pose Estimation in the Wild" + Neurips Dataset Track'2021. + More details can be found in the `paper + `__ . + + AP-10K keypoints:: + + 0: 'L_Eye', + 1: 'R_Eye', + 2: 'Nose', + 3: 'Neck', + 4: 'root of tail', + 5: 'L_Shoulder', + 6: 'L_Elbow', + 7: 'L_F_Paw', + 8: 'R_Shoulder', + 9: 'R_Elbow', + 10: 'R_F_Paw, + 11: 'L_Hip', + 12: 'L_Knee', + 13: 'L_B_Paw', + 14: 'R_Hip', + 15: 'R_Knee', + 16: 'R_B_Paw' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/ap10k.py') diff --git a/mmpose/datasets/datasets/animal/atrw_dataset.py b/mmpose/datasets/datasets/animal/atrw_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..de5b1a09a0510969ea0a6d57c15e5bd13104b99b --- /dev/null +++ b/mmpose/datasets/datasets/animal/atrw_dataset.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class ATRWDataset(BaseCocoStyleDataset): + """ATRW dataset for animal pose estimation. + + "ATRW: A Benchmark for Amur Tiger Re-identification in the Wild" + ACM MM'2020. + More details can be found in the `paper + `__ . + + ATRW keypoints:: + + 0: "left_ear", + 1: "right_ear", + 2: "nose", + 3: "right_shoulder", + 4: "right_front_paw", + 5: "left_shoulder", + 6: "left_front_paw", + 7: "right_hip", + 8: "right_knee", + 9: "right_back_paw", + 10: "left_hip", + 11: "left_knee", + 12: "left_back_paw", + 13: "tail", + 14: "center" + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/atrw.py') diff --git a/mmpose/datasets/datasets/animal/fly_dataset.py b/mmpose/datasets/datasets/animal/fly_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b614d9b9f77b1e2eb7f067ea6cfb21d788857554 --- /dev/null +++ b/mmpose/datasets/datasets/animal/fly_dataset.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class FlyDataset(BaseCocoStyleDataset): + """FlyDataset for animal pose estimation. + + "Fast animal pose estimation using deep neural networks" + Nature methods'2019. More details can be found in the `paper + `__ . + + Vinegar Fly keypoints:: + + 0: "head", + 1: "eyeL", + 2: "eyeR", + 3: "neck", + 4: "thorax", + 5: "abdomen", + 6: "forelegR1", + 7: "forelegR2", + 8: "forelegR3", + 9: "forelegR4", + 10: "midlegR1", + 11: "midlegR2", + 12: "midlegR3", + 13: "midlegR4", + 14: "hindlegR1", + 15: "hindlegR2", + 16: "hindlegR3", + 17: "hindlegR4", + 18: "forelegL1", + 19: "forelegL2", + 20: "forelegL3", + 21: "forelegL4", + 22: "midlegL1", + 23: "midlegL2", + 24: "midlegL3", + 25: "midlegL4", + 26: "hindlegL1", + 27: "hindlegL2", + 28: "hindlegL3", + 29: "hindlegL4", + 30: "wingL", + 31: "wingR" + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/fly.py') diff --git a/mmpose/datasets/datasets/animal/horse10_dataset.py b/mmpose/datasets/datasets/animal/horse10_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0c25dba6a705045b731bddd176bf20a46c285764 --- /dev/null +++ b/mmpose/datasets/datasets/animal/horse10_dataset.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class Horse10Dataset(BaseCocoStyleDataset): + """Horse10Dataset for animal pose estimation. + + "Pretraining boosts out-of-domain robustness for pose estimation" + WACV'2021. More details can be found in the `paper + `__ . + + Horse-10 keypoints:: + + 0: 'Nose', + 1: 'Eye', + 2: 'Nearknee', + 3: 'Nearfrontfetlock', + 4: 'Nearfrontfoot', + 5: 'Offknee', + 6: 'Offfrontfetlock', + 7: 'Offfrontfoot', + 8: 'Shoulder', + 9: 'Midshoulder', + 10: 'Elbow', + 11: 'Girth', + 12: 'Wither', + 13: 'Nearhindhock', + 14: 'Nearhindfetlock', + 15: 'Nearhindfoot', + 16: 'Hip', + 17: 'Stifle', + 18: 'Offhindhock', + 19: 'Offhindfetlock', + 20: 'Offhindfoot', + 21: 'Ischium' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/horse10.py') diff --git a/mmpose/datasets/datasets/animal/locust_dataset.py b/mmpose/datasets/datasets/animal/locust_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3ada76034db8e9cbc25d68ccd9a430ea62394c74 --- /dev/null +++ b/mmpose/datasets/datasets/animal/locust_dataset.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class LocustDataset(BaseCocoStyleDataset): + """LocustDataset for animal pose estimation. + + "DeepPoseKit, a software toolkit for fast and robust animal + pose estimation using deep learning" Elife'2019. + More details can be found in the `paper + `__ . + + Desert Locust keypoints:: + + 0: "head", + 1: "neck", + 2: "thorax", + 3: "abdomen1", + 4: "abdomen2", + 5: "anttipL", + 6: "antbaseL", + 7: "eyeL", + 8: "forelegL1", + 9: "forelegL2", + 10: "forelegL3", + 11: "forelegL4", + 12: "midlegL1", + 13: "midlegL2", + 14: "midlegL3", + 15: "midlegL4", + 16: "hindlegL1", + 17: "hindlegL2", + 18: "hindlegL3", + 19: "hindlegL4", + 20: "anttipR", + 21: "antbaseR", + 22: "eyeR", + 23: "forelegR1", + 24: "forelegR2", + 25: "forelegR3", + 26: "forelegR4", + 27: "midlegR1", + 28: "midlegR2", + 29: "midlegR3", + 30: "midlegR4", + 31: "hindlegR1", + 32: "hindlegR2", + 33: "hindlegR3", + 34: "hindlegR4" + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/locust.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw Locust annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + + # get bbox in shape [1, 4], formatted as xywh + # use the entire image which is 160x160 + bbox = np.array([0, 0, 160, 160], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': ann['num_keypoints'], + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'id': ann['id'], + } + + return data_info diff --git a/mmpose/datasets/datasets/animal/macaque_dataset.py b/mmpose/datasets/datasets/animal/macaque_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..08da981a1a2299efaadaf727b3960e769999fc35 --- /dev/null +++ b/mmpose/datasets/datasets/animal/macaque_dataset.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class MacaqueDataset(BaseCocoStyleDataset): + """MacaquePose dataset for animal pose estimation. + + "MacaquePose: A novel 'in the wild' macaque monkey pose dataset + for markerless motion capture" bioRxiv'2020. + More details can be found in the `paper + `__ . + + Macaque keypoints:: + + 0: 'nose', + 1: 'left_eye', + 2: 'right_eye', + 3: 'left_ear', + 4: 'right_ear', + 5: 'left_shoulder', + 6: 'right_shoulder', + 7: 'left_elbow', + 8: 'right_elbow', + 9: 'left_wrist', + 10: 'right_wrist', + 11: 'left_hip', + 12: 'right_hip', + 13: 'left_knee', + 14: 'right_knee', + 15: 'left_ankle', + 16: 'right_ankle' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/macaque.py') diff --git a/mmpose/datasets/datasets/animal/zebra_dataset.py b/mmpose/datasets/datasets/animal/zebra_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b399a8479bcf18b8b33115b4cd703563e1a846d3 --- /dev/null +++ b/mmpose/datasets/datasets/animal/zebra_dataset.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class ZebraDataset(BaseCocoStyleDataset): + """ZebraDataset for animal pose estimation. + + "DeepPoseKit, a software toolkit for fast and robust animal + pose estimation using deep learning" Elife'2019. + More details can be found in the `paper + `__ . + + Zebra keypoints:: + + 0: "snout", + 1: "head", + 2: "neck", + 3: "forelegL1", + 4: "forelegR1", + 5: "hindlegL1", + 6: "hindlegR1", + 7: "tailbase", + 8: "tailtip" + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/zebra.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw Zebra annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + + # get bbox in shape [1, 4], formatted as xywh + # use the entire image which is 160x160 + bbox = np.array([0, 0, 160, 160], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = ann['num_keypoints'] + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'id': ann['id'], + } + + return data_info diff --git a/mmpose/datasets/datasets/base/__init__.py b/mmpose/datasets/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23bb4efb48daaf71d4d29e46834251c8da3ebbb9 --- /dev/null +++ b/mmpose/datasets/datasets/base/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_coco_style_dataset import BaseCocoStyleDataset + +__all__ = ['BaseCocoStyleDataset'] diff --git a/mmpose/datasets/datasets/base/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/base/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffb67dad34dba321c2ac3f2e884463515bb33473 Binary files /dev/null and b/mmpose/datasets/datasets/base/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/base/__pycache__/base_coco_style_dataset.cpython-38.pyc b/mmpose/datasets/datasets/base/__pycache__/base_coco_style_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2860e8f32b3377bd8d2704d6fa334b71420a7e3 Binary files /dev/null and b/mmpose/datasets/datasets/base/__pycache__/base_coco_style_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/base/base_coco_style_dataset.py b/mmpose/datasets/datasets/base/base_coco_style_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3b592813d8faa614ef26623421a69b497ba3f982 --- /dev/null +++ b/mmpose/datasets/datasets/base/base_coco_style_dataset.py @@ -0,0 +1,458 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from copy import deepcopy +from itertools import filterfalse, groupby +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +from mmengine.dataset import BaseDataset, force_full_init +from mmengine.fileio import exists, get_local_path, load +from mmengine.utils import is_list_of +from xtcocotools.coco import COCO + +from mmpose.registry import DATASETS +from mmpose.structures.bbox import bbox_xywh2xyxy +from ..utils import parse_pose_metainfo + + +@DATASETS.register_module() +class BaseCocoStyleDataset(BaseDataset): + """Base class for COCO-style datasets. + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. + Default: ``dict(img='')``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict() + + def __init__(self, + ann_file: str = '', + bbox_file: Optional[str] = None, + data_mode: str = 'topdown', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000): + + if data_mode not in {'topdown', 'bottomup'}: + raise ValueError( + f'{self.__class__.__name__} got invalid data_mode: ' + f'{data_mode}. Should be "topdown" or "bottomup".') + self.data_mode = data_mode + + if bbox_file: + if self.data_mode != 'topdown': + raise ValueError( + f'{self.__class__.__name__} is set to {self.data_mode}: ' + 'mode, while "bbox_file" is only ' + 'supported in topdown mode.') + + if not test_mode: + raise ValueError( + f'{self.__class__.__name__} has `test_mode==False` ' + 'while "bbox_file" is only ' + 'supported when `test_mode==True`.') + self.bbox_file = bbox_file + + super().__init__( + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=pipeline, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + @classmethod + def _load_metainfo(cls, metainfo: dict = None) -> dict: + """Collect meta information from the dictionary of meta. + + Args: + metainfo (dict): Raw data of pose meta information. + + Returns: + dict: Parsed meta information. + """ + + if metainfo is None: + metainfo = deepcopy(cls.METAINFO) + + if not isinstance(metainfo, dict): + raise TypeError( + f'metainfo should be a dict, but got {type(metainfo)}') + + # parse pose metainfo if it has been assigned + if metainfo: + metainfo = parse_pose_metainfo(metainfo) + return metainfo + + @force_full_init + def prepare_data(self, idx) -> Any: + """Get data processed by ``self.pipeline``. + + :class:`BaseCocoStyleDataset` overrides this method from + :class:`mmengine.dataset.BaseDataset` to add the metainfo into + the ``data_info`` before it is passed to the pipeline. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + data_info = self.get_data_info(idx) + + return self.pipeline(data_info) + + def get_data_info(self, idx: int) -> dict: + """Get data info by index. + + Args: + idx (int): Index of data info. + + Returns: + dict: Data info. + """ + data_info = super().get_data_info(idx) + + # Add metainfo items that are required in the pipeline and the model + metainfo_keys = [ + 'upper_body_ids', 'lower_body_ids', 'flip_pairs', + 'dataset_keypoint_weights', 'flip_indices', 'skeleton_links' + ] + + for key in metainfo_keys: + assert key not in data_info, ( + f'"{key}" is a reserved key for `metainfo`, but already ' + 'exists in the `data_info`.') + + data_info[key] = deepcopy(self._metainfo[key]) + + return data_info + + def load_data_list(self) -> List[dict]: + """Load data list from COCO annotation file or person detection result + file.""" + + if self.bbox_file: + data_list = self._load_detection_results() + else: + instance_list, image_list = self._load_annotations() + + if self.data_mode == 'topdown': + data_list = self._get_topdown_data_infos(instance_list) + else: + data_list = self._get_bottomup_data_infos( + instance_list, image_list) + + return data_list + + def _load_annotations(self) -> Tuple[List[dict], List[dict]]: + """Load data from annotations in COCO format.""" + + assert exists(self.ann_file), 'Annotation file does not exist' + + with get_local_path(self.ann_file) as local_path: + self.coco = COCO(local_path) + # set the metainfo about categories, which is a list of dict + # and each dict contains the 'id', 'name', etc. about this category + self._metainfo['CLASSES'] = self.coco.loadCats(self.coco.getCatIds()) + + instance_list = [] + image_list = [] + + for img_id in self.coco.getImgIds(): + img = self.coco.loadImgs(img_id)[0] + img.update({ + 'img_id': + img_id, + 'img_path': + osp.join(self.data_prefix['img'], img['file_name']), + }) + image_list.append(img) + + ann_ids = self.coco.getAnnIds(imgIds=img_id) + for ann in self.coco.loadAnns(ann_ids): + + instance_info = self.parse_data_info( + dict(raw_ann_info=ann, raw_img_info=img)) + + # skip invalid instance annotation. + if not instance_info: + continue + + instance_list.append(instance_info) + return instance_list, image_list + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw COCO annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict | None: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + # filter invalid instance + if 'bbox' not in ann or 'keypoints' not in ann: + return None + + img_w, img_h = img['width'], img['height'] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann['bbox'] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + if 'num_keypoints' in ann: + num_keypoints = ann['num_keypoints'] + else: + num_keypoints = np.count_nonzero(keypoints.max(axis=2)) + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img['img_path'], + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann.get('iscrowd', 0), + 'segmentation': ann.get('segmentation', None), + 'id': ann['id'], + 'category_id': ann['category_id'], + # store the raw annotation of the instance + # it is useful for evaluation without providing ann_file + 'raw_ann_info': copy.deepcopy(ann), + } + + if 'crowdIndex' in img: + data_info['crowd_index'] = img['crowdIndex'] + + return data_info + + @staticmethod + def _is_valid_instance(data_info: Dict) -> bool: + """Check a data info is an instance with valid bbox and keypoint + annotations.""" + # crowd annotation + if 'iscrowd' in data_info and data_info['iscrowd']: + return False + # invalid keypoints + if 'num_keypoints' in data_info and data_info['num_keypoints'] == 0: + return False + # invalid bbox + if 'bbox' in data_info: + bbox = data_info['bbox'][0] + w, h = bbox[2:4] - bbox[:2] + if w <= 0 or h <= 0: + return False + # invalid keypoints + if 'keypoints' in data_info: + if np.max(data_info['keypoints']) <= 0: + return False + return True + + def _get_topdown_data_infos(self, instance_list: List[Dict]) -> List[Dict]: + """Organize the data list in top-down mode.""" + # sanitize data samples + data_list_tp = list(filter(self._is_valid_instance, instance_list)) + + return data_list_tp + + def _get_bottomup_data_infos(self, instance_list: List[Dict], + image_list: List[Dict]) -> List[Dict]: + """Organize the data list in bottom-up mode.""" + + # bottom-up data list + data_list_bu = [] + + used_img_ids = set() + + # group instances by img_id + for img_id, data_infos in groupby(instance_list, + lambda x: x['img_id']): + used_img_ids.add(img_id) + data_infos = list(data_infos) + + # image data + img_path = data_infos[0]['img_path'] + data_info_bu = { + 'img_id': img_id, + 'img_path': img_path, + } + + for key in data_infos[0].keys(): + if key not in data_info_bu: + seq = [d[key] for d in data_infos] + if isinstance(seq[0], np.ndarray): + seq = np.concatenate(seq, axis=0) + data_info_bu[key] = seq + + # The segmentation annotation of invalid objects will be used + # to generate valid region mask in the pipeline. + invalid_segs = [] + for data_info_invalid in filterfalse(self._is_valid_instance, + data_infos): + if 'segmentation' in data_info_invalid: + invalid_segs.append(data_info_invalid['segmentation']) + data_info_bu['invalid_segs'] = invalid_segs + + data_list_bu.append(data_info_bu) + + # add images without instance for evaluation + if self.test_mode: + for img_info in image_list: + if img_info['img_id'] not in used_img_ids: + data_info_bu = { + 'img_id': img_info['img_id'], + 'img_path': img_info['img_path'], + 'id': list(), + 'raw_ann_info': None, + } + data_list_bu.append(data_info_bu) + + return data_list_bu + + def _load_detection_results(self) -> List[dict]: + """Load data from detection results with dummy keypoint annotations.""" + + assert exists(self.ann_file), 'Annotation file does not exist' + assert exists(self.bbox_file), 'Bbox file does not exist' + # load detection results + det_results = load(self.bbox_file) + assert is_list_of(det_results, dict) + + # load coco annotations to build image id-to-name index + with get_local_path(self.ann_file) as local_path: + self.coco = COCO(local_path) + # set the metainfo about categories, which is a list of dict + # and each dict contains the 'id', 'name', etc. about this category + self._metainfo['CLASSES'] = self.coco.loadCats(self.coco.getCatIds()) + + num_keypoints = self.metainfo['num_keypoints'] + data_list = [] + id_ = 0 + for det in det_results: + # remove non-human instances + if det['category_id'] != 1: + continue + + img = self.coco.loadImgs(det['image_id'])[0] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + bbox_xywh = np.array( + det['bbox'][:4], dtype=np.float32).reshape(1, 4) + bbox = bbox_xywh2xyxy(bbox_xywh) + bbox_score = np.array(det['score'], dtype=np.float32).reshape(1) + + # use dummy keypoint location and visibility + keypoints = np.zeros((1, num_keypoints, 2), dtype=np.float32) + keypoints_visible = np.ones((1, num_keypoints), dtype=np.float32) + + data_list.append({ + 'img_id': det['image_id'], + 'img_path': img_path, + 'img_shape': (img['height'], img['width']), + 'bbox': bbox, + 'bbox_score': bbox_score, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'id': id_, + }) + + id_ += 1 + + return data_list + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. Defaults return full + ``data_list``. + + If 'bbox_score_thr` in filter_cfg, the annotation with bbox_score below + the threshold `bbox_score_thr` will be filtered out. + """ + + data_list = self.data_list + + if self.filter_cfg is None: + return data_list + + # filter out annotations with a bbox_score below the threshold + if 'bbox_score_thr' in self.filter_cfg: + + if self.data_mode != 'topdown': + raise ValueError( + f'{self.__class__.__name__} is set to {self.data_mode} ' + 'mode, while "bbox_score_thr" is only supported in ' + 'topdown mode.') + + thr = self.filter_cfg['bbox_score_thr'] + data_list = list( + filterfalse(lambda ann: ann['bbox_score'] < thr, data_list)) + + return data_list diff --git a/mmpose/datasets/datasets/body/__init__.py b/mmpose/datasets/datasets/body/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4aeef851907f1799bfc1e3507d66c8e7752f32e --- /dev/null +++ b/mmpose/datasets/datasets/body/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .aic_dataset import AicDataset +from .coco_dataset import CocoDataset +from .crowdpose_dataset import CrowdPoseDataset +from .jhmdb_dataset import JhmdbDataset +from .mhp_dataset import MhpDataset +from .mpii_dataset import MpiiDataset +from .mpii_trb_dataset import MpiiTrbDataset +from .ochuman_dataset import OCHumanDataset +from .posetrack18_dataset import PoseTrack18Dataset +from .posetrack18_video_dataset import PoseTrack18VideoDataset + +__all__ = [ + 'CocoDataset', 'MpiiDataset', 'MpiiTrbDataset', 'AicDataset', + 'CrowdPoseDataset', 'OCHumanDataset', 'MhpDataset', 'PoseTrack18Dataset', + 'JhmdbDataset', 'PoseTrack18VideoDataset' +] diff --git a/mmpose/datasets/datasets/body/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9614d92c8bf229e22822c14d7d9474c93623b2aa Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/aic_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/aic_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3df2dd0883e58fefb3362125b44e5d3979bbee09 Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/aic_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/coco_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/coco_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f09a57dfe92799d8af7dc978baf6a63871bea2e3 Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/coco_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/crowdpose_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/crowdpose_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c04a5f70ee71f436fa03a9fbb4c7c7e8ff6a02e1 Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/crowdpose_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/jhmdb_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/jhmdb_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..655a75a671439353d5d95eb4dda3f36406a973a7 Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/jhmdb_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/mhp_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/mhp_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..454497bfa500fdc0953164fbd721fa691d37e8a1 Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/mhp_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/mpii_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/mpii_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dab035203d1532b1d6c19cef8f736474c8c9d302 Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/mpii_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/mpii_trb_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/mpii_trb_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..437bcb0d22da4272e8461c21a8cf937fceb93ff9 Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/mpii_trb_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/ochuman_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/ochuman_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b583aa2a759d453c9357df2b74202b0f71c140 Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/ochuman_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/posetrack18_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/posetrack18_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88f2951e19f33db3472a98d76675be7ac0cbce1a Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/posetrack18_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/__pycache__/posetrack18_video_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/posetrack18_video_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ef9d73814c2fc23c63408e36dbc91f911090edf Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/posetrack18_video_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/body/aic_dataset.py b/mmpose/datasets/datasets/body/aic_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c7cccc76fb47b53cd73f3152878e051b442199 --- /dev/null +++ b/mmpose/datasets/datasets/body/aic_dataset.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class AicDataset(BaseCocoStyleDataset): + """AIC dataset for pose estimation. + + "AI Challenger : A Large-scale Dataset for Going Deeper + in Image Understanding", arXiv'2017. + More details can be found in the `paper + `__ + + AIC keypoints:: + + 0: "right_shoulder", + 1: "right_elbow", + 2: "right_wrist", + 3: "left_shoulder", + 4: "left_elbow", + 5: "left_wrist", + 6: "right_hip", + 7: "right_knee", + 8: "right_ankle", + 9: "left_hip", + 10: "left_knee", + 11: "left_ankle", + 12: "head_top", + 13: "neck" + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/aic.py') diff --git a/mmpose/datasets/datasets/body/coco_dataset.py b/mmpose/datasets/datasets/body/coco_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc971f91f70ba28de1b9ae520d10a2f491eb32b --- /dev/null +++ b/mmpose/datasets/datasets/body/coco_dataset.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class CocoDataset(BaseCocoStyleDataset): + """COCO dataset for pose estimation. + + "Microsoft COCO: Common Objects in Context", ECCV'2014. + More details can be found in the `paper + `__ . + + COCO keypoints:: + + 0: 'nose', + 1: 'left_eye', + 2: 'right_eye', + 3: 'left_ear', + 4: 'right_ear', + 5: 'left_shoulder', + 6: 'right_shoulder', + 7: 'left_elbow', + 8: 'right_elbow', + 9: 'left_wrist', + 10: 'right_wrist', + 11: 'left_hip', + 12: 'right_hip', + 13: 'left_knee', + 14: 'right_knee', + 15: 'left_ankle', + 16: 'right_ankle' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/coco.py') diff --git a/mmpose/datasets/datasets/body/crowdpose_dataset.py b/mmpose/datasets/datasets/body/crowdpose_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4218708ff27b37dce7992d73695193442207b6d9 --- /dev/null +++ b/mmpose/datasets/datasets/body/crowdpose_dataset.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class CrowdPoseDataset(BaseCocoStyleDataset): + """CrowdPose dataset for pose estimation. + + "CrowdPose: Efficient Crowded Scenes Pose Estimation and + A New Benchmark", CVPR'2019. + More details can be found in the `paper + `__. + + CrowdPose keypoints:: + + 0: 'left_shoulder', + 1: 'right_shoulder', + 2: 'left_elbow', + 3: 'right_elbow', + 4: 'left_wrist', + 5: 'right_wrist', + 6: 'left_hip', + 7: 'right_hip', + 8: 'left_knee', + 9: 'right_knee', + 10: 'left_ankle', + 11: 'right_ankle', + 12: 'top_head', + 13: 'neck' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/crowdpose.py') diff --git a/mmpose/datasets/datasets/body/jhmdb_dataset.py b/mmpose/datasets/datasets/body/jhmdb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7d72a7ddc5129af5e1de144853e5389b09465fd8 --- /dev/null +++ b/mmpose/datasets/datasets/body/jhmdb_dataset.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class JhmdbDataset(BaseCocoStyleDataset): + """JhmdbDataset dataset for pose estimation. + + "Towards understanding action recognition", ICCV'2013. + More details can be found in the `paper + `__ + + sub-JHMDB keypoints:: + + 0: "neck", + 1: "belly", + 2: "head", + 3: "right_shoulder", + 4: "left_shoulder", + 5: "right_hip", + 6: "left_hip", + 7: "right_elbow", + 8: "left_elbow", + 9: "right_knee", + 10: "left_knee", + 11: "right_wrist", + 12: "left_wrist", + 13: "right_ankle", + 14: "left_ankle" + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/jhmdb.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw COCO annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + img_w, img_h = img['width'], img['height'] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann['bbox'] + # JHMDB uses matlab format, index is 1-based, + # we should first convert to 0-based index + x -= 1 + y -= 1 + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + # JHMDB uses matlab format, index is 1-based, + # we should first convert to 0-based index + keypoints = _keypoints[..., :2] - 1 + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = np.count_nonzero(keypoints.max(axis=2)) + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann.get('iscrowd', 0), + 'segmentation': ann.get('segmentation', None), + 'id': ann['id'], + } + + return data_info diff --git a/mmpose/datasets/datasets/body/mhp_dataset.py b/mmpose/datasets/datasets/body/mhp_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..55d33602536383898c8b65ca48994d33c1616bea --- /dev/null +++ b/mmpose/datasets/datasets/body/mhp_dataset.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class MhpDataset(BaseCocoStyleDataset): + """MHPv2.0 dataset for pose estimation. + + "Understanding Humans in Crowded Scenes: Deep Nested Adversarial + Learning and A New Benchmark for Multi-Human Parsing", ACM MM'2018. + More details can be found in the `paper + `__ + + MHP keypoints:: + + 0: "right ankle", + 1: "right knee", + 2: "right hip", + 3: "left hip", + 4: "left knee", + 5: "left ankle", + 6: "pelvis", + 7: "thorax", + 8: "upper neck", + 9: "head top", + 10: "right wrist", + 11: "right elbow", + 12: "right shoulder", + 13: "left shoulder", + 14: "left elbow", + 15: "left wrist", + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/mhp.py') diff --git a/mmpose/datasets/datasets/body/mpii_dataset.py b/mmpose/datasets/datasets/body/mpii_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..237f1ab2b61a6bda816e626b987de31560fda22a --- /dev/null +++ b/mmpose/datasets/datasets/body/mpii_dataset.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import Callable, List, Optional, Sequence, Tuple, Union + +import numpy as np +from mmengine.fileio import exists, get_local_path +from scipy.io import loadmat + +from mmpose.registry import DATASETS +from mmpose.structures.bbox import bbox_cs2xyxy +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class MpiiDataset(BaseCocoStyleDataset): + """MPII Dataset for pose estimation. + + "2D Human Pose Estimation: New Benchmark and State of the Art Analysis" + ,CVPR'2014. More details can be found in the `paper + `__ . + + MPII keypoints:: + + 0: 'right_ankle' + 1: 'right_knee', + 2: 'right_hip', + 3: 'left_hip', + 4: 'left_knee', + 5: 'left_ankle', + 6: 'pelvis', + 7: 'thorax', + 8: 'upper_neck', + 9: 'head_top', + 10: 'right_wrist', + 11: 'right_elbow', + 12: 'right_shoulder', + 13: 'left_shoulder', + 14: 'left_elbow', + 15: 'left_wrist' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + headbox_file (str, optional): The path of ``mpii_gt_val.mat`` which + provides the headboxes information used for ``PCKh``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/mpii.py') + + def __init__(self, + ann_file: str = '', + bbox_file: Optional[str] = None, + headbox_file: Optional[str] = None, + data_mode: str = 'topdown', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000): + + if headbox_file: + if data_mode != 'topdown': + raise ValueError( + f'{self.__class__.__name__} is set to {data_mode}: ' + 'mode, while "headbox_file" is only ' + 'supported in topdown mode.') + + if not test_mode: + raise ValueError( + f'{self.__class__.__name__} has `test_mode==False` ' + 'while "headbox_file" is only ' + 'supported when `test_mode==True`.') + + headbox_file_type = headbox_file[-3:] + allow_headbox_file_type = ['mat'] + if headbox_file_type not in allow_headbox_file_type: + raise KeyError( + f'The head boxes file type {headbox_file_type} is not ' + f'supported. Should be `mat` but got {headbox_file_type}.') + self.headbox_file = headbox_file + + super().__init__( + ann_file=ann_file, + bbox_file=bbox_file, + data_mode=data_mode, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=pipeline, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + def _load_annotations(self) -> Tuple[List[dict], List[dict]]: + """Load data from annotations in MPII format.""" + + assert exists(self.ann_file), 'Annotation file does not exist' + with get_local_path(self.ann_file) as local_path: + with open(local_path) as anno_file: + self.anns = json.load(anno_file) + + if self.headbox_file: + assert exists(self.headbox_file), 'Headbox file does not exist' + with get_local_path(self.headbox_file) as local_path: + self.headbox_dict = loadmat(local_path) + headboxes_src = np.transpose(self.headbox_dict['headboxes_src'], + [2, 0, 1]) + SC_BIAS = 0.6 + + instance_list = [] + image_list = [] + used_img_ids = set() + ann_id = 0 + + # mpii bbox scales are normalized with factor 200. + pixel_std = 200. + + for idx, ann in enumerate(self.anns): + center = np.array(ann['center'], dtype=np.float32) + scale = np.array([ann['scale'], ann['scale']], + dtype=np.float32) * pixel_std + + # Adjust center/scale slightly to avoid cropping limbs + if center[0] != -1: + center[1] = center[1] + 15. / pixel_std * scale[1] + + # MPII uses matlab format, index is 1-based, + # we should first convert to 0-based index + center = center - 1 + + # unify shape with coco datasets + center = center.reshape(1, -1) + scale = scale.reshape(1, -1) + bbox = bbox_cs2xyxy(center, scale) + + # load keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + keypoints = np.array(ann['joints']).reshape(1, -1, 2) + keypoints_visible = np.array(ann['joints_vis']).reshape(1, -1) + + instance_info = { + 'id': ann_id, + 'img_id': int(ann['image'].split('.')[0]), + 'img_path': osp.join(self.data_prefix['img'], ann['image']), + 'bbox_center': center, + 'bbox_scale': scale, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + } + + if self.headbox_file: + # calculate the diagonal length of head box as norm_factor + headbox = headboxes_src[idx] + head_size = np.linalg.norm(headbox[1] - headbox[0], axis=0) + head_size *= SC_BIAS + instance_info['head_size'] = head_size.reshape(1, -1) + + if instance_info['img_id'] not in used_img_ids: + used_img_ids.add(instance_info['img_id']) + image_list.append({ + 'img_id': instance_info['img_id'], + 'img_path': instance_info['img_path'], + }) + + instance_list.append(instance_info) + ann_id = ann_id + 1 + + return instance_list, image_list diff --git a/mmpose/datasets/datasets/body/mpii_trb_dataset.py b/mmpose/datasets/datasets/body/mpii_trb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bb96ad876f483dd0b01946030485a53defcf41c8 --- /dev/null +++ b/mmpose/datasets/datasets/body/mpii_trb_dataset.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import List, Tuple + +import numpy as np +from mmengine.fileio import exists, get_local_path + +from mmpose.registry import DATASETS +from mmpose.structures.bbox import bbox_cs2xyxy +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class MpiiTrbDataset(BaseCocoStyleDataset): + """MPII-TRB Dataset dataset for pose estimation. + + "TRB: A Novel Triplet Representation for Understanding 2D Human Body", + ICCV'2019. More details can be found in the `paper + `__ . + + MPII-TRB keypoints:: + + 0: 'left_shoulder' + 1: 'right_shoulder' + 2: 'left_elbow' + 3: 'right_elbow' + 4: 'left_wrist' + 5: 'right_wrist' + 6: 'left_hip' + 7: 'right_hip' + 8: 'left_knee' + 9: 'right_knee' + 10: 'left_ankle' + 11: 'right_ankle' + 12: 'head' + 13: 'neck' + + 14: 'right_neck' + 15: 'left_neck' + 16: 'medial_right_shoulder' + 17: 'lateral_right_shoulder' + 18: 'medial_right_bow' + 19: 'lateral_right_bow' + 20: 'medial_right_wrist' + 21: 'lateral_right_wrist' + 22: 'medial_left_shoulder' + 23: 'lateral_left_shoulder' + 24: 'medial_left_bow' + 25: 'lateral_left_bow' + 26: 'medial_left_wrist' + 27: 'lateral_left_wrist' + 28: 'medial_right_hip' + 29: 'lateral_right_hip' + 30: 'medial_right_knee' + 31: 'lateral_right_knee' + 32: 'medial_right_ankle' + 33: 'lateral_right_ankle' + 34: 'medial_left_hip' + 35: 'lateral_left_hip' + 36: 'medial_left_knee' + 37: 'lateral_left_knee' + 38: 'medial_left_ankle' + 39: 'lateral_left_ankle' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/mpii_trb.py') + + def _load_annotations(self) -> Tuple[List[dict], List[dict]]: + """Load data from annotations in MPII-TRB format.""" + + assert exists(self.ann_file), 'Annotation file does not exist' + with get_local_path(self.ann_file) as local_path: + with open(local_path) as anno_file: + self.data = json.load(anno_file) + + imgid2info = {img['id']: img for img in self.data['images']} + + instance_list = [] + image_list = [] + used_img_ids = set() + + # mpii-trb bbox scales are normalized with factor 200. + pixel_std = 200. + + for ann in self.data['annotations']: + img_id = ann['image_id'] + + # center, scale in shape [1, 2] and bbox in [1, 4] + center = np.array([ann['center']], dtype=np.float32) + scale = np.array([[ann['scale'], ann['scale']]], + dtype=np.float32) * pixel_std + bbox = bbox_cs2xyxy(center, scale) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + img_path = osp.join(self.data_prefix['img'], + imgid2info[img_id]['file_name']) + + instance_info = { + 'id': ann['id'], + 'img_id': img_id, + 'img_path': img_path, + 'bbox_center': center, + 'bbox_scale': scale, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': ann['num_joints'], + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + } + + # val set + if 'headbox' in ann: + instance_info['headbox'] = np.array( + ann['headbox'], dtype=np.float32) + + instance_list.append(instance_info) + if instance_info['img_id'] not in used_img_ids: + used_img_ids.add(instance_info['img_id']) + image_list.append({ + 'img_id': instance_info['img_id'], + 'img_path': instance_info['img_path'], + }) + + instance_list = sorted(instance_list, key=lambda x: x['id']) + return instance_list, image_list diff --git a/mmpose/datasets/datasets/body/ochuman_dataset.py b/mmpose/datasets/datasets/body/ochuman_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..695d090ea998dd530e0f65f902916107e77c4f6d --- /dev/null +++ b/mmpose/datasets/datasets/body/ochuman_dataset.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class OCHumanDataset(BaseCocoStyleDataset): + """OChuman dataset for pose estimation. + + "Pose2Seg: Detection Free Human Instance Segmentation", CVPR'2019. + More details can be found in the `paper + `__ . + + "Occluded Human (OCHuman)" dataset contains 8110 heavily occluded + human instances within 4731 images. OCHuman dataset is designed for + validation and testing. To evaluate on OCHuman, the model should be + trained on COCO training set, and then test the robustness of the + model to occlusion using OCHuman. + + OCHuman keypoints (same as COCO):: + + 0: 'nose', + 1: 'left_eye', + 2: 'right_eye', + 3: 'left_ear', + 4: 'right_ear', + 5: 'left_shoulder', + 6: 'right_shoulder', + 7: 'left_elbow', + 8: 'right_elbow', + 9: 'left_wrist', + 10: 'right_wrist', + 11: 'left_hip', + 12: 'right_hip', + 13: 'left_knee', + 14: 'right_knee', + 15: 'left_ankle', + 16: 'right_ankle' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/ochuman.py') diff --git a/mmpose/datasets/datasets/body/posetrack18_dataset.py b/mmpose/datasets/datasets/body/posetrack18_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b8110c107f6869085ed795c8f1f0338d2c6ed21d --- /dev/null +++ b/mmpose/datasets/datasets/body/posetrack18_dataset.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class PoseTrack18Dataset(BaseCocoStyleDataset): + """PoseTrack18 dataset for pose estimation. + + "Posetrack: A benchmark for human pose estimation and tracking", CVPR'2018. + More details can be found in the `paper + `__ . + + PoseTrack2018 keypoints:: + + 0: 'nose', + 1: 'head_bottom', + 2: 'head_top', + 3: 'left_ear', + 4: 'right_ear', + 5: 'left_shoulder', + 6: 'right_shoulder', + 7: 'left_elbow', + 8: 'right_elbow', + 9: 'left_wrist', + 10: 'right_wrist', + 11: 'left_hip', + 12: 'right_hip', + 13: 'left_knee', + 14: 'right_knee', + 15: 'left_ankle', + 16: 'right_ankle' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/posetrack18.py') diff --git a/mmpose/datasets/datasets/body/posetrack18_video_dataset.py b/mmpose/datasets/datasets/body/posetrack18_video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5fe8646c54c6205b7bf6e0e1889dc13db02b02 --- /dev/null +++ b/mmpose/datasets/datasets/body/posetrack18_video_dataset.py @@ -0,0 +1,389 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Callable, List, Optional, Sequence, Union + +import numpy as np +from mmengine.fileio import exists, get_local_path, load +from mmengine.utils import is_list_of +from xtcocotools.coco import COCO + +from mmpose.registry import DATASETS +from mmpose.structures.bbox import bbox_xywh2xyxy +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class PoseTrack18VideoDataset(BaseCocoStyleDataset): + """PoseTrack18 dataset for video pose estimation. + + "Posetrack: A benchmark for human pose estimation and tracking", CVPR'2018. + More details can be found in the `paper + `__ . + + PoseTrack2018 keypoints:: + + 0: 'nose', + 1: 'head_bottom', + 2: 'head_top', + 3: 'left_ear', + 4: 'right_ear', + 5: 'left_shoulder', + 6: 'right_shoulder', + 7: 'left_elbow', + 8: 'right_elbow', + 9: 'left_wrist', + 10: 'right_wrist', + 11: 'left_hip', + 12: 'right_hip', + 13: 'left_knee', + 14: 'right_knee', + 15: 'left_ankle', + 16: 'right_ankle' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + frame_weights (List[Union[int, float]] ): The weight of each frame + for aggregation. The first weight is for the center frame, then on + ascending order of frame indices. Note that the length of + ``frame_weights`` should be consistent with the number of sampled + frames. Default: [0.0, 1.0] + frame_sampler_mode (str): Specifies the mode of frame sampler: + ``'fixed'`` or ``'random'``. In ``'fixed'`` mode, each frame + index relative to the center frame is fixed, specified by + ``frame_indices``, while in ``'random'`` mode, each frame index + relative to the center frame is sampled from ``frame_range`` + with certain randomness. Default: ``'random'``. + frame_range (int | List[int], optional): The sampling range of + supporting frames in the same video for center frame. + Only valid when ``frame_sampler_mode`` is ``'random'``. + Default: ``None``. + num_sampled_frame(int, optional): The number of sampled frames, except + the center frame. Only valid when ``frame_sampler_mode`` is + ``'random'``. Default: 1. + frame_indices (Sequence[int], optional): The sampled frame indices, + including the center frame indicated by 0. Only valid when + ``frame_sampler_mode`` is ``'fixed'``. Default: ``None``. + ph_fill_len (int): The length of the placeholder to fill in the + image filenames. Default: 6 + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img='')``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/posetrack18.py') + + def __init__(self, + ann_file: str = '', + bbox_file: Optional[str] = None, + data_mode: str = 'topdown', + frame_weights: List[Union[int, float]] = [0.0, 1.0], + frame_sampler_mode: str = 'random', + frame_range: Optional[Union[int, List[int]]] = None, + num_sampled_frame: Optional[int] = None, + frame_indices: Optional[Sequence[int]] = None, + ph_fill_len: int = 6, + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000): + assert sum(frame_weights) == 1, 'Invalid `frame_weights`: should sum'\ + f' to 1.0, but got {frame_weights}.' + for weight in frame_weights: + assert weight >= 0, 'frame_weight can not be a negative value.' + self.frame_weights = np.array(frame_weights) + + if frame_sampler_mode not in {'fixed', 'random'}: + raise ValueError( + f'{self.__class__.__name__} got invalid frame_sampler_mode: ' + f'{frame_sampler_mode}. Should be `"fixed"` or `"random"`.') + self.frame_sampler_mode = frame_sampler_mode + + if frame_sampler_mode == 'random': + assert frame_range is not None, \ + '`frame_sampler_mode` is set as `random`, ' \ + 'please specify the `frame_range`.' + + if isinstance(frame_range, int): + assert frame_range >= 0, \ + 'frame_range can not be a negative value.' + self.frame_range = [-frame_range, frame_range] + + elif isinstance(frame_range, Sequence): + assert len(frame_range) == 2, 'The length must be 2.' + assert frame_range[0] <= 0 and frame_range[ + 1] >= 0 and frame_range[1] > frame_range[ + 0], 'Invalid `frame_range`' + for i in frame_range: + assert isinstance(i, int), 'Each element must be int.' + self.frame_range = frame_range + else: + raise TypeError( + f'The type of `frame_range` must be int or Sequence, ' + f'but got {type(frame_range)}.') + + assert num_sampled_frame is not None, \ + '`frame_sampler_mode` is set as `random`, please specify ' \ + '`num_sampled_frame`, e.g. the number of sampled frames.' + + assert len(frame_weights) == num_sampled_frame + 1, \ + f'the length of frame_weights({len(frame_weights)}) '\ + f'does not match the number of sampled adjacent '\ + f'frames({num_sampled_frame})' + self.frame_indices = None + self.num_sampled_frame = num_sampled_frame + + if frame_sampler_mode == 'fixed': + assert frame_indices is not None, \ + '`frame_sampler_mode` is set as `fixed`, ' \ + 'please specify the `frame_indices`.' + assert len(frame_weights) == len(frame_indices), \ + f'the length of frame_weights({len(frame_weights)}) does not '\ + f'match the length of frame_indices({len(frame_indices)}).' + frame_indices.sort() + self.frame_indices = frame_indices + self.frame_range = None + self.num_sampled_frame = None + + self.ph_fill_len = ph_fill_len + + super().__init__( + ann_file=ann_file, + bbox_file=bbox_file, + data_mode=data_mode, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=pipeline, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + # filter invalid instance + if 'bbox' not in ann or 'keypoints' not in ann or max( + ann['keypoints']) == 0: + return None + + img_w, img_h = img['width'], img['height'] + # get the bbox of the center frame + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann['bbox'] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # get the keypoints of the center frame + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + # deal with multiple image paths + img_paths: list = [] + # get the image path of the center frame + center_img_path = osp.join(self.data_prefix['img'], img['file_name']) + # append the center image path first + img_paths.append(center_img_path) + + # select the frame indices + if self.frame_sampler_mode == 'fixed': + indices = self.frame_indices + else: # self.frame_sampler_mode == 'random': + low, high = self.frame_range + indices = np.random.randint(low, high + 1, self.num_sampled_frame) + + nframes = int(img['nframes']) + file_name = img['file_name'] + ref_idx = int(osp.splitext(osp.basename(file_name))[0]) + + for idx in indices: + if self.test_mode and idx == 0: + continue + # the supporting frame index + support_idx = ref_idx + idx + # clip the frame index to make sure that it does not exceed + # the boundings of frame indices + support_idx = np.clip(support_idx, 0, nframes - 1) + sup_img_path = osp.join( + osp.dirname(center_img_path), + str(support_idx).zfill(self.ph_fill_len) + '.jpg') + + img_paths.append(sup_img_path) + + data_info = { + 'img_id': int(img['frame_id']), + 'img_path': img_paths, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': ann['num_keypoints'], + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'frame_weights': self.frame_weights, + 'id': ann['id'], + } + + return data_info + + def _load_detection_results(self) -> List[dict]: + """Load data from detection results with dummy keypoint annotations.""" + assert exists(self.ann_file), 'Annotation file does not exist' + assert exists(self.bbox_file), 'Bbox file does not exist' + + # load detection results + det_results = load(self.bbox_file) + assert is_list_of(det_results, dict) + + # load coco annotations to build image id-to-name index + with get_local_path(self.ann_file) as local_path: + self.coco = COCO(local_path) + + # mapping image name to id + name2id = {} + # mapping image id to name + id2name = {} + for img_id, image in self.coco.imgs.items(): + file_name = image['file_name'] + id2name[img_id] = file_name + name2id[file_name] = img_id + + num_keypoints = self.metainfo['num_keypoints'] + data_list = [] + id_ = 0 + for det in det_results: + # remove non-human instances + if det['category_id'] != 1: + continue + + # get the predicted bbox and bbox_score + bbox_xywh = np.array( + det['bbox'][:4], dtype=np.float32).reshape(1, 4) + bbox = bbox_xywh2xyxy(bbox_xywh) + bbox_score = np.array(det['score'], dtype=np.float32).reshape(1) + + # use dummy keypoint location and visibility + keypoints = np.zeros((1, num_keypoints, 2), dtype=np.float32) + keypoints_visible = np.ones((1, num_keypoints), dtype=np.float32) + + # deal with different bbox file formats + if 'nframes' in det: + nframes = int(det['nframes']) + else: + if 'image_name' in det: + img_id = name2id[det['image_name']] + else: + img_id = det['image_id'] + img_ann = self.coco.loadImgs(img_id)[0] + nframes = int(img_ann['nframes']) + + # deal with multiple image paths + img_paths: list = [] + if 'image_name' in det: + image_name = det['image_name'] + else: + image_name = id2name[det['image_id']] + # get the image path of the center frame + center_img_path = osp.join(self.data_prefix['img'], image_name) + # append the center image path first + img_paths.append(center_img_path) + + # "images/val/012834_mpii_test/000000.jpg" -->> "000000.jpg" + center_image_name = image_name.split('/')[-1] + ref_idx = int(center_image_name.replace('.jpg', '')) + + # select the frame indices + if self.frame_sampler_mode == 'fixed': + indices = self.frame_indices + else: # self.frame_sampler_mode == 'random': + low, high = self.frame_range + indices = np.random.randint(low, high + 1, + self.num_sampled_frame) + + for idx in indices: + if self.test_mode and idx == 0: + continue + # the supporting frame index + support_idx = ref_idx + idx + # clip the frame index to make sure that it does not exceed + # the boundings of frame indices + support_idx = np.clip(support_idx, 0, nframes - 1) + sup_img_path = center_img_path.replace( + center_image_name, + str(support_idx).zfill(self.ph_fill_len) + '.jpg') + + img_paths.append(sup_img_path) + + data_list.append({ + 'img_id': det['image_id'], + 'img_path': img_paths, + 'frame_weights': self.frame_weights, + 'bbox': bbox, + 'bbox_score': bbox_score, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'id': id_, + }) + + id_ += 1 + + return data_list diff --git a/mmpose/datasets/datasets/face/__init__.py b/mmpose/datasets/datasets/face/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..700cb605f7e5bb5177ce382b0f162edd8959d277 --- /dev/null +++ b/mmpose/datasets/datasets/face/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .aflw_dataset import AFLWDataset +from .coco_wholebody_face_dataset import CocoWholeBodyFaceDataset +from .cofw_dataset import COFWDataset +from .face_300w_dataset import Face300WDataset +from .lapa_dataset import LapaDataset +from .wflw_dataset import WFLWDataset + +__all__ = [ + 'Face300WDataset', 'WFLWDataset', 'AFLWDataset', 'COFWDataset', + 'CocoWholeBodyFaceDataset', 'LapaDataset' +] diff --git a/mmpose/datasets/datasets/face/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccd906c49009614660108413fc1c45a16d52ddc4 Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/face/__pycache__/aflw_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/aflw_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..724b99e6a5770018a7e1fbc36c72c6bbb7c9012f Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/aflw_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/face/__pycache__/coco_wholebody_face_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/coco_wholebody_face_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d26f6fe00f0077c32a3c0fc34a7156b1b8b8ec5 Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/coco_wholebody_face_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/face/__pycache__/cofw_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/cofw_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..918a8bd74e3060ed34160702559144e4955cb837 Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/cofw_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/face/__pycache__/face_300w_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/face_300w_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcbd9ad4e84e84cb6d25de643a75187d93045813 Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/face_300w_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/face/__pycache__/lapa_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/lapa_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46ed2e5ab63ccf931c226c7bf17857963df9e610 Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/lapa_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/face/__pycache__/wflw_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/wflw_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b92d0794ade8e34977b148a230e8e3469567241 Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/wflw_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/face/aflw_dataset.py b/mmpose/datasets/datasets/face/aflw_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..deda0974bb58ba52371f727e788342b5502987a5 --- /dev/null +++ b/mmpose/datasets/datasets/face/aflw_dataset.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from mmpose.structures.bbox import bbox_cs2xyxy +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class AFLWDataset(BaseCocoStyleDataset): + """AFLW dataset for face keypoint localization. + + "Annotated Facial Landmarks in the Wild: A Large-scale, + Real-world Database for Facial Landmark Localization". + In Proc. First IEEE International Workshop on Benchmarking + Facial Image Analysis Technologies, 2011. + + The landmark annotations follow the 19 points mark-up. The definition + can be found in `https://www.tugraz.at/institute/icg/research` + `/team-bischof/lrs/downloads/aflw/` + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/aflw.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw Face AFLW annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + + # aflw bbox scales are normalized with factor 200. + pixel_std = 200. + + # center, scale in shape [1, 2] and bbox in [1, 4] + center = np.array([ann['center']], dtype=np.float32) + scale = np.array([[ann['scale'], ann['scale']]], + dtype=np.float32) * pixel_std + bbox = bbox_cs2xyxy(center, scale) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = ann['num_keypoints'] + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_center': center, + 'bbox_scale': scale, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'id': ann['id'], + } + + if self.test_mode: + # 'box_size' is used as normalization factor + assert 'box_size' in ann, '"box_size" is missing in annotation, '\ + 'which is required for evaluation.' + data_info['box_size'] = ann['box_size'] + + return data_info diff --git a/mmpose/datasets/datasets/face/coco_wholebody_face_dataset.py b/mmpose/datasets/datasets/face/coco_wholebody_face_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bc2c5be386012a341879a3910dcf72e5672e5d6f --- /dev/null +++ b/mmpose/datasets/datasets/face/coco_wholebody_face_dataset.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class CocoWholeBodyFaceDataset(BaseCocoStyleDataset): + """CocoWholeBodyDataset for face keypoint localization. + + `Whole-Body Human Pose Estimation in the Wild', ECCV'2020. + More details can be found in the `paper + `__ . + + The face landmark annotations follow the 68 points mark-up. + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict( + from_file='configs/_base_/datasets/coco_wholebody_face.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw CocoWholeBody Face annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + # filter invalid instance + if not ann['face_valid'] or max(ann['face_kpts']) <= 0: + return None + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + img_w, img_h = img['width'], img['height'] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann['face_box'] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['face_kpts'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = np.count_nonzero(keypoints.max(axis=2)) + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'id': ann['id'], + } + return data_info diff --git a/mmpose/datasets/datasets/face/cofw_dataset.py b/mmpose/datasets/datasets/face/cofw_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec2a37efd8b7fc125ebd87df88bc9c99cd86250 --- /dev/null +++ b/mmpose/datasets/datasets/face/cofw_dataset.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class COFWDataset(BaseCocoStyleDataset): + """COFW dataset for face keypoint localization. + + "Robust face landmark estimation under occlusion", ICCV'2013. + + The landmark annotations follow the 29 points mark-up. The definition + can be found in `http://www.vision.caltech.edu/xpburgos/ICCV13/`__ . + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/cofw.py') diff --git a/mmpose/datasets/datasets/face/face_300w_dataset.py b/mmpose/datasets/datasets/face/face_300w_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c70e892b4f707dc5990566b760e0a2566eb4a53f --- /dev/null +++ b/mmpose/datasets/datasets/face/face_300w_dataset.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from mmpose.structures.bbox import bbox_cs2xyxy +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class Face300WDataset(BaseCocoStyleDataset): + """300W dataset for face keypoint localization. + + "300 faces In-the-wild challenge: Database and results", + Image and Vision Computing (IMAVIS) 2019. + + The landmark annotations follow the 68 points mark-up. The definition + can be found in `https://ibug.doc.ic.ac.uk/resources/300-W/`. + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/300w.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw Face300W annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + + # 300w bbox scales are normalized with factor 200. + pixel_std = 200. + + # center, scale in shape [1, 2] and bbox in [1, 4] + center = np.array([ann['center']], dtype=np.float32) + scale = np.array([[ann['scale'], ann['scale']]], + dtype=np.float32) * pixel_std + bbox = bbox_cs2xyxy(center, scale) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = ann['num_keypoints'] + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_center': center, + 'bbox_scale': scale, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'id': ann['id'], + } + return data_info diff --git a/mmpose/datasets/datasets/face/lapa_dataset.py b/mmpose/datasets/datasets/face/lapa_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5bdc4ec08cebe690ae1f5f2a659e9c087634ec --- /dev/null +++ b/mmpose/datasets/datasets/face/lapa_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class LapaDataset(BaseCocoStyleDataset): + """LaPa dataset for face keypoint localization. + + "A New Dataset and Boundary-Attention Semantic Segmentation + for Face Parsing", AAAI'2020. + + The landmark annotations follow the 106 points mark-up. The definition + can be found in `https://github.com/JDAI-CV/lapa-dataset/`__ . + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/lapa.py') diff --git a/mmpose/datasets/datasets/face/wflw_dataset.py b/mmpose/datasets/datasets/face/wflw_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9c1c23053ce87fc92a234334e637e7a8e0402a9e --- /dev/null +++ b/mmpose/datasets/datasets/face/wflw_dataset.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from mmpose.structures.bbox import bbox_cs2xyxy +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class WFLWDataset(BaseCocoStyleDataset): + """WFLW dataset for face keypoint localization. + + "Look at Boundary: A Boundary-Aware Face Alignment Algorithm", + CVPR'2018. + + The landmark annotations follow the 98 points mark-up. The definition + can be found in `https://wywu.github.io/projects/LAB/WFLW.html`__ . + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/wflw.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw Face WFLW annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + + # wflw bbox scales are normalized with factor 200. + pixel_std = 200. + + # center, scale in shape [1, 2] and bbox in [1, 4] + center = np.array([ann['center']], dtype=np.float32) + scale = np.array([[ann['scale'], ann['scale']]], + dtype=np.float32) * pixel_std + bbox = bbox_cs2xyxy(center, scale) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = ann['num_keypoints'] + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_center': center, + 'bbox_scale': scale, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'id': ann['id'], + } + return data_info diff --git a/mmpose/datasets/datasets/fashion/__init__.py b/mmpose/datasets/datasets/fashion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8be25dede3d16dfb7754c794d86d7f236e8f647b --- /dev/null +++ b/mmpose/datasets/datasets/fashion/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .deepfashion2_dataset import DeepFashion2Dataset +from .deepfashion_dataset import DeepFashionDataset + +__all__ = ['DeepFashionDataset', 'DeepFashion2Dataset'] diff --git a/mmpose/datasets/datasets/fashion/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/fashion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..023eb00097f5274f331bfd4571e51264af1edb4a Binary files /dev/null and b/mmpose/datasets/datasets/fashion/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/fashion/__pycache__/deepfashion2_dataset.cpython-38.pyc b/mmpose/datasets/datasets/fashion/__pycache__/deepfashion2_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e773929a8a21d85be41798b06303fa59ffd5580c Binary files /dev/null and b/mmpose/datasets/datasets/fashion/__pycache__/deepfashion2_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/fashion/__pycache__/deepfashion_dataset.cpython-38.pyc b/mmpose/datasets/datasets/fashion/__pycache__/deepfashion_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b90b5436aca0c00ed0872ab0f9e27c0f00d99b5 Binary files /dev/null and b/mmpose/datasets/datasets/fashion/__pycache__/deepfashion_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/fashion/deepfashion2_dataset.py b/mmpose/datasets/datasets/fashion/deepfashion2_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c3cde9bf97be254927aa6a06f46bdcc225f14283 --- /dev/null +++ b/mmpose/datasets/datasets/fashion/deepfashion2_dataset.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module(name='DeepFashion2Dataset') +class DeepFashion2Dataset(BaseCocoStyleDataset): + """DeepFashion2 dataset for fashion landmark detection.""" + + METAINFO: dict = dict(from_file='configs/_base_/datasets/deepfashion2.py') diff --git a/mmpose/datasets/datasets/fashion/deepfashion_dataset.py b/mmpose/datasets/datasets/fashion/deepfashion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a0aa4937323e41333d48a82a11862e68ffc697f0 --- /dev/null +++ b/mmpose/datasets/datasets/fashion/deepfashion_dataset.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, List, Optional, Sequence, Union + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class DeepFashionDataset(BaseCocoStyleDataset): + """DeepFashion dataset (full-body clothes) for fashion landmark detection. + + "DeepFashion: Powering Robust Clothes Recognition + and Retrieval with Rich Annotations", CVPR'2016. + "Fashion Landmark Detection in the Wild", ECCV'2016. + + The dataset contains 3 categories for full-body, upper-body and lower-body. + + Fashion landmark indexes for upper-body clothes:: + + 0: 'left collar', + 1: 'right collar', + 2: 'left sleeve', + 3: 'right sleeve', + 4: 'left hem', + 5: 'right hem' + + Fashion landmark indexes for lower-body clothes:: + + 0: 'left waistline', + 1: 'right waistline', + 2: 'left hem', + 3: 'right hem' + + Fashion landmark indexes for full-body clothes:: + + 0: 'left collar', + 1: 'right collar', + 2: 'left sleeve', + 3: 'right sleeve', + 4: 'left waistline', + 5: 'right waistline', + 6: 'left hem', + 7: 'right hem' + + Args: + ann_file (str): Annotation file path. Default: ''. + subset (str): Specifies the subset of body: ``'full'``, ``'upper'`` or + ``'lower'``. Default: '', which means ``'full'``. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img='')``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + def __init__(self, + ann_file: str = '', + subset: str = '', + bbox_file: Optional[str] = None, + data_mode: str = 'topdown', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000): + self._check_subset_and_metainfo(subset) + + super().__init__( + ann_file=ann_file, + bbox_file=bbox_file, + data_mode=data_mode, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=pipeline, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + @classmethod + def _check_subset_and_metainfo(cls, subset: str = '') -> None: + """Check the subset of body and set the corresponding metainfo. + + Args: + subset(str): the subset of body: could be ``'full'``, ``'upper'`` + or ``'lower'``. Default: '', which means ``'full'``. + """ + if subset == '' or subset == 'full': + cls.METAINFO = dict( + from_file='configs/_base_/datasets/deepfashion_full.py') + elif subset == 'upper': + cls.METAINFO = dict( + from_file='configs/_base_/datasets/deepfashion_upper.py') + elif subset == 'lower': + cls.METAINFO = dict( + from_file='configs/_base_/datasets/deepfashion_lower.py') + else: + raise ValueError( + f'{cls.__class__.__name__} got invalid subset: ' + f'{subset}. Should be "full", "lower" or "upper".') diff --git a/mmpose/datasets/datasets/hand/__init__.py b/mmpose/datasets/datasets/hand/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e2222be9e981c51b38ac879b534966ec6aa861 --- /dev/null +++ b/mmpose/datasets/datasets/hand/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .coco_wholebody_hand_dataset import CocoWholeBodyHandDataset +from .freihand_dataset import FreiHandDataset +from .onehand10k_dataset import OneHand10KDataset +from .panoptic_hand2d_dataset import PanopticHand2DDataset +from .rhd2d_dataset import Rhd2DDataset + +__all__ = [ + 'OneHand10KDataset', 'FreiHandDataset', 'PanopticHand2DDataset', + 'Rhd2DDataset', 'CocoWholeBodyHandDataset' +] diff --git a/mmpose/datasets/datasets/hand/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4996f11b284444f4f56f9c59626bd57135355a37 Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/hand/__pycache__/coco_wholebody_hand_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/coco_wholebody_hand_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d99b789df71e7d2ab77040055c0c15458ec2d16 Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/coco_wholebody_hand_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/hand/__pycache__/freihand_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/freihand_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ab8c55ae6b9e45ee4763bb84bc3a45b9f7c02f5 Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/freihand_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/hand/__pycache__/onehand10k_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/onehand10k_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..181339aa6f8122d17590f93c3d016fd44a71966f Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/onehand10k_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/hand/__pycache__/panoptic_hand2d_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/panoptic_hand2d_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55a5d3d52cf5d03826d21baf37887eaf0145622a Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/panoptic_hand2d_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/hand/__pycache__/rhd2d_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/rhd2d_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0602b0335da5d433cdfe49b6be1fa13ab689a5df Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/rhd2d_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/hand/coco_wholebody_hand_dataset.py b/mmpose/datasets/datasets/hand/coco_wholebody_hand_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dba0132f584f812c323ea040675931a14ca8841c --- /dev/null +++ b/mmpose/datasets/datasets/hand/coco_wholebody_hand_dataset.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List, Tuple + +import numpy as np +from mmengine.fileio import exists, get_local_path +from xtcocotools.coco import COCO + +from mmpose.registry import DATASETS +from mmpose.structures.bbox import bbox_xywh2xyxy +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class CocoWholeBodyHandDataset(BaseCocoStyleDataset): + """CocoWholeBodyDataset for hand pose estimation. + + "Whole-Body Human Pose Estimation in the Wild", ECCV'2020. + More details can be found in the `paper + `__ . + + COCO-WholeBody Hand keypoints:: + + 0: 'wrist', + 1: 'thumb1', + 2: 'thumb2', + 3: 'thumb3', + 4: 'thumb4', + 5: 'forefinger1', + 6: 'forefinger2', + 7: 'forefinger3', + 8: 'forefinger4', + 9: 'middle_finger1', + 10: 'middle_finger2', + 11: 'middle_finger3', + 12: 'middle_finger4', + 13: 'ring_finger1', + 14: 'ring_finger2', + 15: 'ring_finger3', + 16: 'ring_finger4', + 17: 'pinky_finger1', + 18: 'pinky_finger2', + 19: 'pinky_finger3', + 20: 'pinky_finger4' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict( + from_file='configs/_base_/datasets/coco_wholebody_hand.py') + + def _load_annotations(self) -> Tuple[List[dict], List[dict]]: + """Load data from annotations in COCO format.""" + + assert exists(self.ann_file), 'Annotation file does not exist' + + with get_local_path(self.ann_file) as local_path: + self.coco = COCO(local_path) + instance_list = [] + image_list = [] + id = 0 + + for img_id in self.coco.getImgIds(): + img = self.coco.loadImgs(img_id)[0] + + img.update({ + 'img_id': + img_id, + 'img_path': + osp.join(self.data_prefix['img'], img['file_name']), + }) + image_list.append(img) + + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False) + anns = self.coco.loadAnns(ann_ids) + for ann in anns: + for type in ['left', 'right']: + # filter invalid hand annotations, there might be two + # valid instances (left and right hand) in one image + if ann[f'{type}hand_valid'] and max( + ann[f'{type}hand_kpts']) > 0: + + bbox_xywh = np.array( + ann[f'{type}hand_box'], + dtype=np.float32).reshape(1, 4) + + bbox = bbox_xywh2xyxy(bbox_xywh) + + _keypoints = np.array( + ann[f'{type}hand_kpts'], + dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = np.count_nonzero(keypoints.max(axis=2)) + + instance_info = { + 'img_id': ann['image_id'], + 'img_path': img['img_path'], + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'segmentation': ann['segmentation'], + 'id': id, + } + instance_list.append(instance_info) + id = id + 1 + + instance_list = sorted(instance_list, key=lambda x: x['id']) + return instance_list, image_list diff --git a/mmpose/datasets/datasets/hand/freihand_dataset.py b/mmpose/datasets/datasets/hand/freihand_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0e23cdd577d12e6d20656fde59f7da58a45150 --- /dev/null +++ b/mmpose/datasets/datasets/hand/freihand_dataset.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class FreiHandDataset(BaseCocoStyleDataset): + """FreiHand dataset for hand pose estimation. + + "FreiHAND: A Dataset for Markerless Capture of Hand Pose + and Shape from Single RGB Images", ICCV'2019. + More details can be found in the `paper + `__ . + + FreiHand keypoints:: + + 0: 'wrist', + 1: 'thumb1', + 2: 'thumb2', + 3: 'thumb3', + 4: 'thumb4', + 5: 'forefinger1', + 6: 'forefinger2', + 7: 'forefinger3', + 8: 'forefinger4', + 9: 'middle_finger1', + 10: 'middle_finger2', + 11: 'middle_finger3', + 12: 'middle_finger4', + 13: 'ring_finger1', + 14: 'ring_finger2', + 15: 'ring_finger3', + 16: 'ring_finger4', + 17: 'pinky_finger1', + 18: 'pinky_finger2', + 19: 'pinky_finger3', + 20: 'pinky_finger4' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/freihand2d.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw COCO annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + + # use the entire image which is 224x224 + bbox = np.array([0, 0, 224, 224], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = np.count_nonzero(keypoints.max(axis=2)) + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'segmentation': ann['segmentation'], + 'id': ann['id'], + } + + return data_info diff --git a/mmpose/datasets/datasets/hand/onehand10k_dataset.py b/mmpose/datasets/datasets/hand/onehand10k_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3519ace560ef70ce680955bfa82d52c1a11b6b3e --- /dev/null +++ b/mmpose/datasets/datasets/hand/onehand10k_dataset.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class OneHand10KDataset(BaseCocoStyleDataset): + """OneHand10K dataset for hand pose estimation. + + "Mask-pose Cascaded CNN for 2D Hand Pose Estimation from + Single Color Images", TCSVT'2019. + More details can be found in the `paper + `__ . + + OneHand10K keypoints:: + + 0: 'wrist', + 1: 'thumb1', + 2: 'thumb2', + 3: 'thumb3', + 4: 'thumb4', + 5: 'forefinger1', + 6: 'forefinger2', + 7: 'forefinger3', + 8: 'forefinger4', + 9: 'middle_finger1', + 10: 'middle_finger2', + 11: 'middle_finger3', + 12: 'middle_finger4', + 13: 'ring_finger1', + 14: 'ring_finger2', + 15: 'ring_finger3', + 16: 'ring_finger4', + 17: 'pinky_finger1', + 18: 'pinky_finger2', + 19: 'pinky_finger3', + 20: 'pinky_finger4' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/onehand10k.py') diff --git a/mmpose/datasets/datasets/hand/panoptic_hand2d_dataset.py b/mmpose/datasets/datasets/hand/panoptic_hand2d_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..26d364840ebe5756687a72a4de52b0213ffdcea2 --- /dev/null +++ b/mmpose/datasets/datasets/hand/panoptic_hand2d_dataset.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class PanopticHand2DDataset(BaseCocoStyleDataset): + """Panoptic 2D dataset for hand pose estimation. + + "Hand Keypoint Detection in Single Images using Multiview + Bootstrapping", CVPR'2017. + More details can be found in the `paper + `__ . + + Panoptic keypoints:: + + 0: 'wrist', + 1: 'thumb1', + 2: 'thumb2', + 3: 'thumb3', + 4: 'thumb4', + 5: 'forefinger1', + 6: 'forefinger2', + 7: 'forefinger3', + 8: 'forefinger4', + 9: 'middle_finger1', + 10: 'middle_finger2', + 11: 'middle_finger3', + 12: 'middle_finger4', + 13: 'ring_finger1', + 14: 'ring_finger2', + 15: 'ring_finger3', + 16: 'ring_finger4', + 17: 'pinky_finger1', + 18: 'pinky_finger2', + 19: 'pinky_finger3', + 20: 'pinky_finger4' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict( + from_file='configs/_base_/datasets/panoptic_hand2d.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw COCO annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + img_w, img_h = img['width'], img['height'] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann['bbox'] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + num_keypoints = np.count_nonzero(keypoints.max(axis=2)) + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'segmentation': ann['segmentation'], + 'head_size': ann['head_size'], + 'id': ann['id'], + } + + return data_info diff --git a/mmpose/datasets/datasets/hand/rhd2d_dataset.py b/mmpose/datasets/datasets/hand/rhd2d_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc4301590a5f8c8d474b0ef37de4d03309ad0b9 --- /dev/null +++ b/mmpose/datasets/datasets/hand/rhd2d_dataset.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class Rhd2DDataset(BaseCocoStyleDataset): + """Rendered Handpose Dataset for hand pose estimation. + + "Learning to Estimate 3D Hand Pose from Single RGB Images", + ICCV'2017. + More details can be found in the `paper + `__ . + + Rhd keypoints:: + + 0: 'wrist', + 1: 'thumb4', + 2: 'thumb3', + 3: 'thumb2', + 4: 'thumb1', + 5: 'forefinger4', + 6: 'forefinger3', + 7: 'forefinger2', + 8: 'forefinger1', + 9: 'middle_finger4', + 10: 'middle_finger3', + 11: 'middle_finger2', + 12: 'middle_finger1', + 13: 'ring_finger4', + 14: 'ring_finger3', + 15: 'ring_finger2', + 16: 'ring_finger1', + 17: 'pinky_finger4', + 18: 'pinky_finger3', + 19: 'pinky_finger2', + 20: 'pinky_finger1' + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/rhd2d.py') diff --git a/mmpose/datasets/datasets/utils.py b/mmpose/datasets/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5140126163d49eb92a25fe0db1d5d4ebb7144fa6 --- /dev/null +++ b/mmpose/datasets/datasets/utils.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings + +import numpy as np +from mmengine import Config + + +def parse_pose_metainfo(metainfo: dict): + """Load meta information of pose dataset and check its integrity. + + Args: + metainfo (dict): Raw data of pose meta information, which should + contain following contents: + + - "dataset_name" (str): The name of the dataset + - "keypoint_info" (dict): The keypoint-related meta information, + e.g., name, upper/lower body, and symmetry + - "skeleton_info" (dict): The skeleton-related meta information, + e.g., start/end keypoint of limbs + - "joint_weights" (list[float]): The loss weights of keypoints + - "sigmas" (list[float]): The keypoint distribution parameters + to calculate OKS score. See `COCO keypoint evaluation + `__. + + An example of metainfo is shown as follows. + + .. code-block:: none + { + "dataset_name": "coco", + "keypoint_info": + { + 0: + { + "name": "nose", + "type": "upper", + "swap": "", + "color": [51, 153, 255], + }, + 1: + { + "name": "right_eye", + "type": "upper", + "swap": "left_eye", + "color": [51, 153, 255], + }, + ... + }, + "skeleton_info": + { + 0: + { + "link": ("left_ankle", "left_knee"), + "color": [0, 255, 0], + }, + ... + }, + "joint_weights": [1., 1., ...], + "sigmas": [0.026, 0.025, ...], + } + + + A special case is that `metainfo` can have the key "from_file", + which should be the path of a config file. In this case, the + actual metainfo will be loaded by: + + .. code-block:: python + metainfo = mmengine.Config.fromfile(metainfo['from_file']) + + Returns: + Dict: pose meta information that contains following contents: + + - "dataset_name" (str): Same as ``"dataset_name"`` in the input + - "num_keypoints" (int): Number of keypoints + - "keypoint_id2name" (dict): Mapping from keypoint id to name + - "keypoint_name2id" (dict): Mapping from keypoint name to id + - "upper_body_ids" (list): Ids of upper-body keypoint + - "lower_body_ids" (list): Ids of lower-body keypoint + - "flip_indices" (list): The Id of each keypoint's symmetric keypoint + - "flip_pairs" (list): The Ids of symmetric keypoint pairs + - "keypoint_colors" (numpy.ndarray): The keypoint color matrix of + shape [K, 3], where each row is the color of one keypint in bgr + - "num_skeleton_links" (int): The number of links + - "skeleton_links" (list): The links represented by Id pairs of start + and end points + - "skeleton_link_colors" (numpy.ndarray): The link color matrix + - "dataset_keypoint_weights" (numpy.ndarray): Same as the + ``"joint_weights"`` in the input + - "sigmas" (numpy.ndarray): Same as the ``"sigmas"`` in the input + """ + + if 'from_file' in metainfo: + cfg_file = metainfo['from_file'] + if not osp.isfile(cfg_file): + # Search configs in 'mmpose/.mim/configs/' in case that mmpose + # is installed in non-editable mode. + import mmpose + mmpose_path = osp.dirname(mmpose.__file__) + _cfg_file = osp.join(mmpose_path, '.mim', 'configs', '_base_', + 'datasets', osp.basename(cfg_file)) + if osp.isfile(_cfg_file): + warnings.warn( + f'The metainfo config file "{cfg_file}" does not exist. ' + f'A matched config file "{_cfg_file}" will be used ' + 'instead.') + cfg_file = _cfg_file + else: + raise FileNotFoundError( + f'The metainfo config file "{cfg_file}" does not exist.') + + # TODO: remove the nested structure of dataset_info + # metainfo = Config.fromfile(metainfo['from_file']) + metainfo = Config.fromfile(cfg_file).dataset_info + + # check data integrity + assert 'dataset_name' in metainfo + assert 'keypoint_info' in metainfo + assert 'skeleton_info' in metainfo + assert 'joint_weights' in metainfo + assert 'sigmas' in metainfo + + # parse metainfo + parsed = dict( + dataset_name=None, + num_keypoints=None, + keypoint_id2name={}, + keypoint_name2id={}, + upper_body_ids=[], + lower_body_ids=[], + flip_indices=[], + flip_pairs=[], + keypoint_colors=[], + num_skeleton_links=None, + skeleton_links=[], + skeleton_link_colors=[], + dataset_keypoint_weights=None, + sigmas=None, + ) + + parsed['dataset_name'] = metainfo['dataset_name'] + + # parse keypoint information + parsed['num_keypoints'] = len(metainfo['keypoint_info']) + + for kpt_id, kpt in metainfo['keypoint_info'].items(): + kpt_name = kpt['name'] + parsed['keypoint_id2name'][kpt_id] = kpt_name + parsed['keypoint_name2id'][kpt_name] = kpt_id + parsed['keypoint_colors'].append(kpt.get('color', [255, 128, 0])) + + kpt_type = kpt.get('type', '') + if kpt_type == 'upper': + parsed['upper_body_ids'].append(kpt_id) + elif kpt_type == 'lower': + parsed['lower_body_ids'].append(kpt_id) + + swap_kpt = kpt.get('swap', '') + if swap_kpt == kpt_name or swap_kpt == '': + parsed['flip_indices'].append(kpt_name) + else: + parsed['flip_indices'].append(swap_kpt) + pair = (swap_kpt, kpt_name) + if pair not in parsed['flip_pairs']: + parsed['flip_pairs'].append(pair) + + # parse skeleton information + parsed['num_skeleton_links'] = len(metainfo['skeleton_info']) + for _, sk in metainfo['skeleton_info'].items(): + parsed['skeleton_links'].append(sk['link']) + parsed['skeleton_link_colors'].append(sk.get('color', [96, 96, 255])) + + # parse extra information + parsed['dataset_keypoint_weights'] = np.array( + metainfo['joint_weights'], dtype=np.float32) + parsed['sigmas'] = np.array(metainfo['sigmas'], dtype=np.float32) + + # formatting + def _map(src, mapping: dict): + if isinstance(src, (list, tuple)): + cls = type(src) + return cls(_map(s, mapping) for s in src) + else: + return mapping[src] + + parsed['flip_pairs'] = _map( + parsed['flip_pairs'], mapping=parsed['keypoint_name2id']) + parsed['flip_indices'] = _map( + parsed['flip_indices'], mapping=parsed['keypoint_name2id']) + parsed['skeleton_links'] = _map( + parsed['skeleton_links'], mapping=parsed['keypoint_name2id']) + + parsed['keypoint_colors'] = np.array( + parsed['keypoint_colors'], dtype=np.uint8) + parsed['skeleton_link_colors'] = np.array( + parsed['skeleton_link_colors'], dtype=np.uint8) + + return parsed diff --git a/mmpose/datasets/datasets/wholebody/__init__.py b/mmpose/datasets/datasets/wholebody/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..156094c2b03b297ac5e1ad839cc970fc67a0ecf4 --- /dev/null +++ b/mmpose/datasets/datasets/wholebody/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .coco_wholebody_dataset import CocoWholeBodyDataset +from .halpe_dataset import HalpeDataset + +__all__ = ['CocoWholeBodyDataset', 'HalpeDataset'] diff --git a/mmpose/datasets/datasets/wholebody/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/wholebody/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..931b681247e546fbc46f4a343f10edc4770d798d Binary files /dev/null and b/mmpose/datasets/datasets/wholebody/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/wholebody/__pycache__/coco_wholebody_dataset.cpython-38.pyc b/mmpose/datasets/datasets/wholebody/__pycache__/coco_wholebody_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6b2963204ad9ba6d4c41864820fd39d6f0fdb59 Binary files /dev/null and b/mmpose/datasets/datasets/wholebody/__pycache__/coco_wholebody_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/wholebody/__pycache__/halpe_dataset.cpython-38.pyc b/mmpose/datasets/datasets/wholebody/__pycache__/halpe_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..571e0a24eb11996a7c9c8891cc7672f5c1f54c64 Binary files /dev/null and b/mmpose/datasets/datasets/wholebody/__pycache__/halpe_dataset.cpython-38.pyc differ diff --git a/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py b/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..00a2ea418f95c1973881072256c455e3de9f2046 --- /dev/null +++ b/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from typing import Optional + +import numpy as np + +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class CocoWholeBodyDataset(BaseCocoStyleDataset): + """CocoWholeBody dataset for pose estimation. + + "Whole-Body Human Pose Estimation in the Wild", ECCV'2020. + More details can be found in the `paper + `__ . + + COCO-WholeBody keypoints:: + + 0-16: 17 body keypoints, + 17-22: 6 foot keypoints, + 23-90: 68 face keypoints, + 91-132: 42 hand keypoints + + In total, we have 133 keypoints for wholebody pose estimation. + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict( + from_file='configs/_base_/datasets/coco_wholebody.py') + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw COCO annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + img_path = osp.join(self.data_prefix['img'], img['file_name']) + img_w, img_h = img['width'], img['height'] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann['bbox'] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + # COCO-Wholebody: consisting of body, foot, face and hand keypoints + _keypoints = np.array(ann['keypoints'] + ann['foot_kpts'] + + ann['face_kpts'] + ann['lefthand_kpts'] + + ann['righthand_kpts']).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2] > 0) + + num_keypoints = ann['num_keypoints'] + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img_path, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann['iscrowd'], + 'segmentation': ann['segmentation'], + 'id': ann['id'], + 'category_id': ann['category_id'], + # store the raw annotation of the instance + # it is useful for evaluation without providing ann_file + 'raw_ann_info': copy.deepcopy(ann), + } + + return data_info diff --git a/mmpose/datasets/datasets/wholebody/halpe_dataset.py b/mmpose/datasets/datasets/wholebody/halpe_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0699f3b7023b200ee42e3cfe7f475a51123ef190 --- /dev/null +++ b/mmpose/datasets/datasets/wholebody/halpe_dataset.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.registry import DATASETS +from ..base import BaseCocoStyleDataset + + +@DATASETS.register_module() +class HalpeDataset(BaseCocoStyleDataset): + """Halpe dataset for pose estimation. + + 'https://github.com/Fang-Haoshu/Halpe-FullBody' + + Halpe keypoints:: + + 0-19: 20 body keypoints, + 20-25: 6 foot keypoints, + 26-93: 68 face keypoints, + 94-135: 42 hand keypoints + + In total, we have 136 keypoints for wholebody pose estimation. + + Args: + ann_file (str): Annotation file path. Default: ''. + bbox_file (str, optional): Detection result file path. If + ``bbox_file`` is set, detected bboxes loaded from this file will + be used instead of ground-truth bboxes. This setting is only for + evaluation, i.e., ignored when ``test_mode`` is ``False``. + Default: ``None``. + data_mode (str): Specifies the mode of data samples: ``'topdown'`` or + ``'bottomup'``. In ``'topdown'`` mode, each data sample contains + one instance; while in ``'bottomup'`` mode, each data sample + contains all instances in a image. Default: ``'topdown'`` + metainfo (dict, optional): Meta information for dataset, such as class + information. Default: ``None``. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Default: ``None``. + data_prefix (dict, optional): Prefix for training data. Default: + ``dict(img=None, ann=None)``. + filter_cfg (dict, optional): Config for filter data. Default: `None`. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Default: ``None`` which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. + Default: ``True``. + pipeline (list, optional): Processing pipeline. Default: []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Default: ``False``. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Default: ``False``. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Default: 1000. + """ + + METAINFO: dict = dict(from_file='configs/_base_/datasets/halpe.py') diff --git a/mmpose/datasets/samplers.py b/mmpose/datasets/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..d6bb34287a8c6b43552601eeb9b2e7c9a4fa90df --- /dev/null +++ b/mmpose/datasets/samplers.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +import math +from typing import Iterator, List, Optional, Sized, Union + +import torch +from mmengine.dist import get_dist_info, sync_random_seed +from torch.utils.data import Sampler + +from mmpose.datasets import CombinedDataset +from mmpose.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class MultiSourceSampler(Sampler): + """Multi-Source Sampler. According to the sampling ratio, sample data from + different datasets to form batches. + + Args: + dataset (Sized): The dataset + batch_size (int): Size of mini-batch + source_ratio (list[int | float]): The sampling ratio of different + source datasets in a mini-batch + shuffle (bool): Whether shuffle the dataset or not. Defaults to + ``True`` + round_up (bool): Whether to add extra samples to make the number of + samples evenly divisible by the world size. Defaults to True. + seed (int, optional): Random seed. If ``None``, set a random seed. + Defaults to ``None`` + """ + + def __init__(self, + dataset: Sized, + batch_size: int, + source_ratio: List[Union[int, float]], + shuffle: bool = True, + round_up: bool = True, + seed: Optional[int] = None) -> None: + + assert isinstance(dataset, CombinedDataset),\ + f'The dataset must be CombinedDataset, but get {dataset}' + assert isinstance(batch_size, int) and batch_size > 0, \ + 'batch_size must be a positive integer value, ' \ + f'but got batch_size={batch_size}' + assert isinstance(source_ratio, list), \ + f'source_ratio must be a list, but got source_ratio={source_ratio}' + assert len(source_ratio) == len(dataset._lens), \ + 'The length of source_ratio must be equal to ' \ + f'the number of datasets, but got source_ratio={source_ratio}' + + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.cumulative_sizes = [0] + list(itertools.accumulate(dataset._lens)) + self.batch_size = batch_size + self.source_ratio = source_ratio + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size)) + self.num_per_source = [ + int(batch_size * sr / sum(source_ratio)) for sr in source_ratio + ] + self.num_per_source[0] = batch_size - sum(self.num_per_source[1:]) + + assert sum(self.num_per_source) == batch_size, \ + 'The sum of num_per_source must be equal to ' \ + f'batch_size, but get {self.num_per_source}' + + self.seed = sync_random_seed() if seed is None else seed + self.shuffle = shuffle + self.round_up = round_up + self.source2inds = { + source: self._indices_of_rank(len(ds)) + for source, ds in enumerate(dataset.datasets) + } + + def _infinite_indices(self, sample_size: int) -> Iterator[int]: + """Infinitely yield a sequence of indices.""" + g = torch.Generator() + g.manual_seed(self.seed) + while True: + if self.shuffle: + yield from torch.randperm(sample_size, generator=g).tolist() + else: + yield from torch.arange(sample_size).tolist() + + def _indices_of_rank(self, sample_size: int) -> Iterator[int]: + """Slice the infinite indices by rank.""" + yield from itertools.islice( + self._infinite_indices(sample_size), self.rank, None, + self.world_size) + + def __iter__(self) -> Iterator[int]: + batch_buffer = [] + num_iters = self.num_samples // self.batch_size + if self.round_up and self.num_samples > num_iters * self.batch_size: + num_iters += 1 + for i in range(num_iters): + for source, num in enumerate(self.num_per_source): + batch_buffer_per_source = [] + for idx in self.source2inds[source]: + idx += self.cumulative_sizes[source] + batch_buffer_per_source.append(idx) + if len(batch_buffer_per_source) == num: + batch_buffer += batch_buffer_per_source + break + return iter(batch_buffer) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + """Compatible in `epoch-based runner.""" + pass diff --git a/mmpose/datasets/transforms/__init__.py b/mmpose/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61dae74b8c6e756a87151c2ef8ae2423eb507515 --- /dev/null +++ b/mmpose/datasets/transforms/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bottomup_transforms import (BottomupGetHeatmapMask, BottomupRandomAffine, + BottomupResize) +from .common_transforms import (Albumentation, GenerateTarget, + GetBBoxCenterScale, PhotometricDistortion, + RandomBBoxTransform, RandomFlip, + RandomHalfBody) +from .converting import KeypointConverter +from .formatting import PackPoseInputs +from .loading import LoadImage +from .topdown_transforms import TopdownAffine + +__all__ = [ + 'GetBBoxCenterScale', 'RandomBBoxTransform', 'RandomFlip', + 'RandomHalfBody', 'TopdownAffine', 'Albumentation', + 'PhotometricDistortion', 'PackPoseInputs', 'LoadImage', + 'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize', + 'GenerateTarget', 'KeypointConverter' +] diff --git a/mmpose/datasets/transforms/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25f01623b630515298da7ef1e525c23d7d40927b Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/datasets/transforms/__pycache__/bottomup_transforms.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/bottomup_transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76ffe18189f729a5019a1a6f143e7df5387f68d0 Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/bottomup_transforms.cpython-38.pyc differ diff --git a/mmpose/datasets/transforms/__pycache__/common_transforms.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/common_transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ccf52a667cbdc6f3e958b5dfecf0bb3c225c342 Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/common_transforms.cpython-38.pyc differ diff --git a/mmpose/datasets/transforms/__pycache__/converting.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/converting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bc65203f98e4a7cfe327d59bfd9c08a3dea83c2 Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/converting.cpython-38.pyc differ diff --git a/mmpose/datasets/transforms/__pycache__/formatting.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/formatting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bf2b35e6a5b2e170dccbed1983248a441e4efa4 Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/formatting.cpython-38.pyc differ diff --git a/mmpose/datasets/transforms/__pycache__/loading.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/loading.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3368dc2771aafe618300b2ea832ffb6206eaa8f0 Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/loading.cpython-38.pyc differ diff --git a/mmpose/datasets/transforms/__pycache__/topdown_transforms.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/topdown_transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d214bbd6c54fc3192f67f4756cada2ad338af5 Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/topdown_transforms.cpython-38.pyc differ diff --git a/mmpose/datasets/transforms/bottomup_transforms.py b/mmpose/datasets/transforms/bottomup_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c31e0ae17d65f074a89c56210285184f2eeedc0b --- /dev/null +++ b/mmpose/datasets/transforms/bottomup_transforms.py @@ -0,0 +1,517 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import xtcocotools.mask as cocomask +from mmcv.image import imflip_, imresize +from mmcv.transforms import BaseTransform +from mmcv.transforms.utils import cache_randomness +from scipy.stats import truncnorm + +from mmpose.registry import TRANSFORMS +from mmpose.structures.bbox import get_udp_warp_matrix, get_warp_matrix + + +@TRANSFORMS.register_module() +class BottomupGetHeatmapMask(BaseTransform): + """Generate the mask of valid regions from the segmentation annotation. + + Required Keys: + + - img_shape + - invalid_segs (optional) + - warp_mat (optional) + - flip (optional) + - flip_direction (optional) + - heatmaps (optional) + + Added Keys: + + - heatmap_mask + """ + + def _segs_to_mask(self, segs: list, img_shape: Tuple[int, + int]) -> np.ndarray: + """Calculate mask from object segmentations. + + Args: + segs (List): The object segmentation annotations in COCO format + img_shape (Tuple): The image shape in (h, w) + + Returns: + np.ndarray: The binary object mask in size (h, w), where the + object pixels are 1 and background pixels are 0 + """ + + # RLE is a simple yet efficient format for storing binary masks. + # details can be found at `COCO tools `__ + rles = [] + for seg in segs: + rle = cocomask.frPyObjects(seg, img_shape[0], img_shape[1]) + if isinstance(rle, list): + # For non-crowded objects (e.g. human with no visible + # keypoints), the results is a list of rles + rles.extend(rle) + else: + # For crowded objects, the result is a single rle + rles.append(rle) + + if rles: + mask = cocomask.decode(cocomask.merge(rles)) + else: + mask = np.zeros(img_shape, dtype=np.uint8) + + return mask + + def transform(self, results: Dict) -> Optional[dict]: + """The transform function of :class:`BottomupGetHeatmapMask` to perform + photometric distortion on images. + + See ``transform()`` method of :class:`BaseTransform` for details. + + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + invalid_segs = results.get('invalid_segs', []) + img_shape = results['img_shape'] # (img_h, img_w) + input_size = results['input_size'] + + # Calculate the mask of the valid region by negating the segmentation + # mask of invalid objects + mask = 1 - self._segs_to_mask(invalid_segs, img_shape) + + # Apply an affine transform to the mask if the image has been + # transformed + if 'warp_mat' in results: + warp_mat = results['warp_mat'] + + mask = mask.astype(np.float32) + mask = cv2.warpAffine( + mask, warp_mat, input_size, flags=cv2.INTER_LINEAR) + + # Flip the mask if the image has been flipped + if results.get('flip', False): + flip_dir = results['flip_direction'] + if flip_dir is not None: + mask = imflip_(mask, flip_dir) + + # Resize the mask to the same size of heatmaps + if 'heatmaps' in results: + heatmaps = results['heatmaps'] + if isinstance(heatmaps, list): + # Multi-level heatmaps + heatmap_mask = [] + for hm in results['heatmaps']: + h, w = hm.shape[1:3] + _mask = imresize( + mask, size=(w, h), interpolation='bilinear') + heatmap_mask.append(_mask) + else: + h, w = heatmaps.shape[1:3] + heatmap_mask = imresize( + mask, size=(w, h), interpolation='bilinear') + else: + heatmap_mask = mask + + # Binarize the mask(s) + if isinstance(heatmap_mask, list): + results['heatmap_mask'] = [hm > 0.5 for hm in heatmap_mask] + else: + results['heatmap_mask'] = heatmap_mask > 0.5 + + return results + + +@TRANSFORMS.register_module() +class BottomupRandomAffine(BaseTransform): + r"""Randomly shift, resize and rotate the image. + + Required Keys: + + - img + - img_shape + - keypoints (optional) + + Modified Keys: + + - img + - keypoints (optional) + + Added Keys: + + - input_size + - warp_mat + + Args: + input_size (Tuple[int, int]): The input image size of the model in + [w, h] + shift_factor (float): Randomly shift the image in range + :math:`[-dx, dx]` and :math:`[-dy, dy]` in X and Y directions, + where :math:`dx(y) = img_w(h) \cdot shift_factor` in pixels. + Defaults to 0.2 + shift_prob (float): Probability of applying random shift. Defaults to + 1.0 + scale_factor (Tuple[float, float]): Randomly resize the image in range + :math:`[scale_factor[0], scale_factor[1]]`. Defaults to + (0.75, 1.5) + scale_prob (float): Probability of applying random resizing. Defaults + to 1.0 + scale_type (str): wrt ``long`` or ``short`` length of the image. + Defaults to ``short`` + rotate_factor (float): Randomly rotate the bbox in + :math:`[-rotate_factor, rotate_factor]` in degrees. Defaults + to 40.0 + use_udp (bool): Whether use unbiased data processing. See + `UDP (CVPR 2020)`_ for details. Defaults to ``False`` + + .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524 + """ + + def __init__(self, + input_size: Tuple[int, int], + shift_factor: float = 0.2, + shift_prob: float = 1., + scale_factor: Tuple[float, float] = (0.75, 1.5), + scale_prob: float = 1., + scale_type: str = 'short', + rotate_factor: float = 30., + rotate_prob: float = 1, + use_udp: bool = False) -> None: + super().__init__() + + self.input_size = input_size + self.shift_factor = shift_factor + self.shift_prob = shift_prob + self.scale_factor = scale_factor + self.scale_prob = scale_prob + self.scale_type = scale_type + self.rotate_factor = rotate_factor + self.rotate_prob = rotate_prob + self.use_udp = use_udp + + @staticmethod + def _truncnorm(low: float = -1., + high: float = 1., + size: tuple = ()) -> np.ndarray: + """Sample from a truncated normal distribution.""" + return truncnorm.rvs(low, high, size=size).astype(np.float32) + + def _fix_aspect_ratio(self, scale: np.ndarray, aspect_ratio: float): + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = scale + if w > h * aspect_ratio: + if self.scale_type == 'long': + _w, _h = w, w / aspect_ratio + elif self.scale_type == 'short': + _w, _h = h * aspect_ratio, h + else: + raise ValueError(f'Unknown scale type: {self.scale_type}') + else: + if self.scale_type == 'short': + _w, _h = w, w / aspect_ratio + elif self.scale_type == 'long': + _w, _h = h * aspect_ratio, h + else: + raise ValueError(f'Unknown scale type: {self.scale_type}') + return np.array([_w, _h], dtype=scale.dtype) + + @cache_randomness + def _get_transform_params(self) -> Tuple: + """Get random transform parameters. + + Returns: + tuple: + - offset (np.ndarray): Image offset rate in shape (2, ) + - scale (np.ndarray): Image scaling rate factor in shape (1, ) + - rotate (np.ndarray): Image rotation degree in shape (1, ) + """ + # get offset + if np.random.rand() < self.shift_prob: + offset = self._truncnorm(size=(2, )) * self.shift_factor + else: + offset = np.zeros((2, ), dtype=np.float32) + + # get scale + if np.random.rand() < self.scale_prob: + scale_min, scale_max = self.scale_factor + scale = scale_min + (scale_max - scale_min) * ( + self._truncnorm(size=(1, )) + 1) / 2 + else: + scale = np.ones(1, dtype=np.float32) + + # get rotation + if np.random.rand() < self.rotate_prob: + rotate = self._truncnorm() * self.rotate_factor + else: + rotate = 0 + + return offset, scale, rotate + + def transform(self, results: Dict) -> Optional[dict]: + """The transform function of :class:`BottomupRandomAffine` to perform + photometric distortion on images. + + See ``transform()`` method of :class:`BaseTransform` for details. + + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + img_h, img_w = results['img_shape'] + w, h = self.input_size + + offset_rate, scale_rate, rotate = self._get_transform_params() + offset = offset_rate * [img_w, img_h] + scale = scale_rate * [img_w, img_h] + # adjust the scale to match the target aspect ratio + scale = self._fix_aspect_ratio(scale, aspect_ratio=w / h) + + if self.use_udp: + center = np.array([(img_w - 1.0) / 2, (img_h - 1.0) / 2], + dtype=np.float32) + warp_mat = get_udp_warp_matrix( + center=center + offset, + scale=scale, + rot=rotate, + output_size=(w, h)) + else: + center = np.array([img_w / 2, img_h / 2], dtype=np.float32) + warp_mat = get_warp_matrix( + center=center + offset, + scale=scale, + rot=rotate, + output_size=(w, h)) + + # warp image and keypoints + results['img'] = cv2.warpAffine( + results['img'], warp_mat, (int(w), int(h)), flags=cv2.INTER_LINEAR) + + if 'keypoints' in results: + # Only transform (x, y) coordinates + results['keypoints'][..., :2] = cv2.transform( + results['keypoints'][..., :2], warp_mat) + + if 'bbox' in results: + bbox = np.tile(results['bbox'], 2).reshape(-1, 4, 2) + # corner order: left_top, left_bottom, right_top, right_bottom + bbox[:, 1:3, 0] = bbox[:, 0:2, 0] + results['bbox'] = cv2.transform(bbox, warp_mat).reshape(-1, 8) + + results['input_size'] = self.input_size + results['warp_mat'] = warp_mat + + return results + + +@TRANSFORMS.register_module() +class BottomupResize(BaseTransform): + """Resize the image to the input size of the model. Optionally, the image + can be resized to multiple sizes to build a image pyramid for multi-scale + inference. + + Required Keys: + + - img + - ori_shape + + Modified Keys: + + - img + - img_shape + + Added Keys: + + - input_size + - warp_mat + - aug_scale + + Args: + input_size (Tuple[int, int]): The input size of the model in [w, h]. + Note that the actually size of the resized image will be affected + by ``resize_mode`` and ``size_factor``, thus may not exactly equals + to the ``input_size`` + aug_scales (List[float], optional): The extra input scales for + multi-scale testing. If given, the input image will be resized + to different scales to build a image pyramid. And heatmaps from + all scales will be aggregated to make final prediction. Defaults + to ``None`` + size_factor (int): The actual input size will be ceiled to + a multiple of the `size_factor` value at both sides. + Defaults to 16 + resize_mode (str): The method to resize the image to the input size. + Options are: + + - ``'fit'``: The image will be resized according to the + relatively longer side with the aspect ratio kept. The + resized image will entirely fits into the range of the + input size + - ``'expand'``: The image will be resized according to the + relatively shorter side with the aspect ratio kept. The + resized image will exceed the given input size at the + longer side + use_udp (bool): Whether use unbiased data processing. See + `UDP (CVPR 2020)`_ for details. Defaults to ``False`` + + .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524 + """ + + def __init__(self, + input_size: Tuple[int, int], + aug_scales: Optional[List[float]] = None, + size_factor: int = 32, + resize_mode: str = 'fit', + use_udp: bool = False): + super().__init__() + + self.input_size = input_size + self.aug_scales = aug_scales + self.resize_mode = resize_mode + self.size_factor = size_factor + self.use_udp = use_udp + + @staticmethod + def _ceil_to_multiple(size: Tuple[int, int], base: int): + """Ceil the given size (tuple of [w, h]) to a multiple of the base.""" + return tuple(int(np.ceil(s / base) * base) for s in size) + + def _get_input_size(self, img_size: Tuple[int, int], + input_size: Tuple[int, int]) -> Tuple: + """Calculate the actual input size (which the original image will be + resized to) and the padded input size (which the resized image will be + padded to, or which is the size of the model input). + + Args: + img_size (Tuple[int, int]): The original image size in [w, h] + input_size (Tuple[int, int]): The expected input size in [w, h] + + Returns: + tuple: + - actual_input_size (Tuple[int, int]): The target size to resize + the image + - padded_input_size (Tuple[int, int]): The target size to generate + the model input which will contain the resized image + """ + img_w, img_h = img_size + ratio = img_w / img_h + + if self.resize_mode == 'fit': + padded_input_size = self._ceil_to_multiple(input_size, + self.size_factor) + if padded_input_size != input_size: + raise ValueError( + 'When ``resize_mode==\'fit\', the input size (height and' + ' width) should be mulitples of the size_factor(' + f'{self.size_factor}) at all scales. Got invalid input ' + f'size {input_size}.') + + pad_w, pad_h = padded_input_size + rsz_w = min(pad_w, pad_h * ratio) + rsz_h = min(pad_h, pad_w / ratio) + actual_input_size = (rsz_w, rsz_h) + + elif self.resize_mode == 'expand': + _padded_input_size = self._ceil_to_multiple( + input_size, self.size_factor) + pad_w, pad_h = _padded_input_size + rsz_w = max(pad_w, pad_h * ratio) + rsz_h = max(pad_h, pad_w / ratio) + + actual_input_size = (rsz_w, rsz_h) + padded_input_size = self._ceil_to_multiple(actual_input_size, + self.size_factor) + + else: + raise ValueError(f'Invalid resize mode {self.resize_mode}') + + return actual_input_size, padded_input_size + + def transform(self, results: Dict) -> Optional[dict]: + """The transform function of :class:`BottomupResize` to perform + photometric distortion on images. + + See ``transform()`` method of :class:`BaseTransform` for details. + + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + img = results['img'] + img_h, img_w = results['ori_shape'] + w, h = self.input_size + + input_sizes = [(w, h)] + if self.aug_scales: + input_sizes += [(int(w * s), int(h * s)) for s in self.aug_scales] + + imgs = [] + for i, (_w, _h) in enumerate(input_sizes): + + actual_input_size, padded_input_size = self._get_input_size( + img_size=(img_w, img_h), input_size=(_w, _h)) + + if self.use_udp: + center = np.array([(img_w - 1.0) / 2, (img_h - 1.0) / 2], + dtype=np.float32) + scale = np.array([img_w, img_h], dtype=np.float32) + warp_mat = get_udp_warp_matrix( + center=center, + scale=scale, + rot=0, + output_size=actual_input_size) + else: + center = np.array([img_w / 2, img_h / 2], dtype=np.float32) + scale = np.array([ + img_w * padded_input_size[0] / actual_input_size[0], + img_h * padded_input_size[1] / actual_input_size[1] + ], + dtype=np.float32) + warp_mat = get_warp_matrix( + center=center, + scale=scale, + rot=0, + output_size=padded_input_size) + + _img = cv2.warpAffine( + img, warp_mat, padded_input_size, flags=cv2.INTER_LINEAR) + + imgs.append(_img) + + # Store the transform information w.r.t. the main input size + if i == 0: + results['img_shape'] = padded_input_size[::-1] + results['input_center'] = center + results['input_scale'] = scale + results['input_size'] = padded_input_size + + if self.aug_scales: + results['img'] = imgs + results['aug_scales'] = self.aug_scales + else: + results['img'] = imgs[0] + results['aug_scale'] = None + + return results diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8db0ff37c75faec94e6e0df15b9cf2b099db86c9 --- /dev/null +++ b/mmpose/datasets/transforms/common_transforms.py @@ -0,0 +1,1044 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from copy import deepcopy +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import mmcv +import mmengine +import numpy as np +from mmcv.image import imflip +from mmcv.transforms import BaseTransform +from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness +from mmengine import is_list_of +from mmengine.dist import get_dist_info +from scipy.stats import truncnorm + +from mmpose.codecs import * # noqa: F401, F403 +from mmpose.registry import KEYPOINT_CODECS, TRANSFORMS +from mmpose.structures.bbox import bbox_xyxy2cs, flip_bbox +from mmpose.structures.keypoint import flip_keypoints +from mmpose.utils.typing import MultiConfig + +try: + import albumentations +except ImportError: + albumentations = None + +Number = Union[int, float] + + +@TRANSFORMS.register_module() +class GetBBoxCenterScale(BaseTransform): + """Convert bboxes from [x, y, w, h] to center and scale. + + The center is the coordinates of the bbox center, and the scale is the + bbox width and height normalized by a scale factor. + + Required Keys: + + - bbox + + Added Keys: + + - bbox_center + - bbox_scale + + Args: + padding (float): The bbox padding scale that will be multilied to + `bbox_scale`. Defaults to 1.25 + """ + + def __init__(self, padding: float = 1.25) -> None: + super().__init__() + + self.padding = padding + + def transform(self, results: Dict) -> Optional[dict]: + """The transform function of :class:`GetBBoxCenterScale`. + + See ``transform()`` method of :class:`BaseTransform` for details. + + Args: + results (dict): The result dict + + Returns: + dict: The result dict. + """ + if 'bbox_center' in results and 'bbox_scale' in results: + rank, _ = get_dist_info() + if rank == 0: + warnings.warn('Use the existing "bbox_center" and "bbox_scale"' + '. The padding will still be applied.') + results['bbox_scale'] *= self.padding + + else: + bbox = results['bbox'] + center, scale = bbox_xyxy2cs(bbox, padding=self.padding) + + results['bbox_center'] = center + results['bbox_scale'] = scale + + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + f'(padding={self.padding})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomFlip(BaseTransform): + """Randomly flip the image, bbox and keypoints. + + Required Keys: + + - img + - img_shape + - flip_indices + - input_size (optional) + - bbox (optional) + - bbox_center (optional) + - keypoints (optional) + - keypoints_visible (optional) + - img_mask (optional) + + Modified Keys: + + - img + - bbox (optional) + - bbox_center (optional) + - keypoints (optional) + - keypoints_visible (optional) + - img_mask (optional) + + Added Keys: + + - flip + - flip_direction + + Args: + prob (float | list[float]): The flipping probability. If a list is + given, the argument `direction` should be a list with the same + length. And each element in `prob` indicates the flipping + probability of the corresponding one in ``direction``. Defaults + to 0.5 + direction (str | list[str]): The flipping direction. Options are + ``'horizontal'``, ``'vertical'`` and ``'diagonal'``. If a list is + is given, each data sample's flipping direction will be sampled + from a distribution determined by the argument ``prob``. Defaults + to ``'horizontal'``. + """ + + def __init__(self, + prob: Union[float, List[float]] = 0.5, + direction: Union[str, List[str]] = 'horizontal') -> None: + if isinstance(prob, list): + assert is_list_of(prob, float) + assert 0 <= sum(prob) <= 1 + elif isinstance(prob, float): + assert 0 <= prob <= 1 + else: + raise ValueError(f'probs must be float or list of float, but \ + got `{type(prob)}`.') + self.prob = prob + + valid_directions = ['horizontal', 'vertical', 'diagonal'] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert is_list_of(direction, str) + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError(f'direction must be either str or list of str, \ + but got `{type(direction)}`.') + self.direction = direction + + if isinstance(prob, list): + assert len(prob) == len(self.direction) + + @cache_randomness + def _choose_direction(self) -> str: + """Choose the flip direction according to `prob` and `direction`""" + if isinstance(self.direction, + List) and not isinstance(self.direction, str): + # None means non-flip + direction_list: list = list(self.direction) + [None] + elif isinstance(self.direction, str): + # None means non-flip + direction_list = [self.direction, None] + + if isinstance(self.prob, list): + non_prob: float = 1 - sum(self.prob) + prob_list = self.prob + [non_prob] + elif isinstance(self.prob, float): + non_prob = 1. - self.prob + # exclude non-flip + single_ratio = self.prob / (len(direction_list) - 1) + prob_list = [single_ratio] * (len(direction_list) - 1) + [non_prob] + + cur_dir = np.random.choice(direction_list, p=prob_list) + + return cur_dir + + def transform(self, results: dict) -> dict: + """The transform function of :class:`RandomFlip`. + + See ``transform()`` method of :class:`BaseTransform` for details. + + Args: + results (dict): The result dict + + Returns: + dict: The result dict. + """ + + flip_dir = self._choose_direction() + + if flip_dir is None: + results['flip'] = False + results['flip_direction'] = None + else: + results['flip'] = True + results['flip_direction'] = flip_dir + + h, w = results.get('input_size', results['img_shape']) + # flip image and mask + if isinstance(results['img'], list): + results['img'] = [ + imflip(img, direction=flip_dir) for img in results['img'] + ] + else: + results['img'] = imflip(results['img'], direction=flip_dir) + + if 'img_mask' in results: + results['img_mask'] = imflip( + results['img_mask'], direction=flip_dir) + + # flip bboxes + if results.get('bbox', None) is not None: + results['bbox'] = flip_bbox( + results['bbox'], + image_size=(w, h), + bbox_format='xyxy', + direction=flip_dir) + + if results.get('bbox_center', None) is not None: + results['bbox_center'] = flip_bbox( + results['bbox_center'], + image_size=(w, h), + bbox_format='center', + direction=flip_dir) + + # flip keypoints + if results.get('keypoints', None) is not None: + keypoints, keypoints_visible = flip_keypoints( + results['keypoints'], + results.get('keypoints_visible', None), + image_size=(w, h), + flip_indices=results['flip_indices'], + direction=flip_dir) + + results['keypoints'] = keypoints + results['keypoints_visible'] = keypoints_visible + + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'direction={self.direction})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomHalfBody(BaseTransform): + """Data augmentation with half-body transform that keeps only the upper or + lower body at random. + + Required Keys: + + - keypoints + - keypoints_visible + - upper_body_ids + - lower_body_ids + + Modified Keys: + + - bbox + - bbox_center + - bbox_scale + + Args: + min_total_keypoints (int): The minimum required number of total valid + keypoints of a person to apply half-body transform. Defaults to 8 + min_half_keypoints (int): The minimum required number of valid + half-body keypoints of a person to apply half-body transform. + Defaults to 2 + padding (float): The bbox padding scale that will be multilied to + `bbox_scale`. Defaults to 1.5 + prob (float): The probability to apply half-body transform when the + keypoint number meets the requirement. Defaults to 0.3 + """ + + def __init__(self, + min_total_keypoints: int = 9, + min_upper_keypoints: int = 2, + min_lower_keypoints: int = 3, + padding: float = 1.5, + prob: float = 0.3, + upper_prioritized_prob: float = 0.7) -> None: + super().__init__() + self.min_total_keypoints = min_total_keypoints + self.min_upper_keypoints = min_upper_keypoints + self.min_lower_keypoints = min_lower_keypoints + self.padding = padding + self.prob = prob + self.upper_prioritized_prob = upper_prioritized_prob + + def _get_half_body_bbox(self, keypoints: np.ndarray, + half_body_ids: List[int] + ) -> Tuple[np.ndarray, np.ndarray]: + """Get half-body bbox center and scale of a single instance. + + Args: + keypoints (np.ndarray): Keypoints in shape (K, D) + upper_body_ids (list): The list of half-body keypont indices + + Returns: + tuple: A tuple containing half-body bbox center and scale + - center: Center (x, y) of the bbox + - scale: Scale (w, h) of the bbox + """ + + selected_keypoints = keypoints[half_body_ids] + center = selected_keypoints.mean(axis=0)[:2] + + x1, y1 = selected_keypoints.min(axis=0) + x2, y2 = selected_keypoints.max(axis=0) + w = x2 - x1 + h = y2 - y1 + scale = np.array([w, h], dtype=center.dtype) * self.padding + + return center, scale + + @cache_randomness + def _random_select_half_body(self, keypoints_visible: np.ndarray, + upper_body_ids: List[int], + lower_body_ids: List[int] + ) -> List[Optional[List[int]]]: + """Randomly determine whether applying half-body transform and get the + half-body keyponit indices of each instances. + + Args: + keypoints_visible (np.ndarray, optional): The visibility of + keypoints in shape (N, K, 1). + upper_body_ids (list): The list of upper body keypoint indices + lower_body_ids (list): The list of lower body keypoint indices + + Returns: + list[list[int] | None]: The selected half-body keypoint indices + of each instance. ``None`` means not applying half-body transform. + """ + + half_body_ids = [] + + for visible in keypoints_visible: + if visible.sum() < self.min_total_keypoints: + indices = None + elif np.random.rand() > self.prob: + indices = None + else: + upper_valid_ids = [i for i in upper_body_ids if visible[i] > 0] + lower_valid_ids = [i for i in lower_body_ids if visible[i] > 0] + + num_upper = len(upper_valid_ids) + num_lower = len(lower_valid_ids) + + prefer_upper = np.random.rand() < self.upper_prioritized_prob + if (num_upper < self.min_upper_keypoints + and num_lower < self.min_lower_keypoints): + indices = None + elif num_lower < self.min_lower_keypoints: + indices = upper_valid_ids + elif num_upper < self.min_upper_keypoints: + indices = lower_valid_ids + else: + indices = ( + upper_valid_ids if prefer_upper else lower_valid_ids) + + half_body_ids.append(indices) + + return half_body_ids + + def transform(self, results: Dict) -> Optional[dict]: + """The transform function of :class:`HalfBodyTransform`. + + See ``transform()`` method of :class:`BaseTransform` for details. + + Args: + results (dict): The result dict + + Returns: + dict: The result dict. + """ + + half_body_ids = self._random_select_half_body( + keypoints_visible=results['keypoints_visible'], + upper_body_ids=results['upper_body_ids'], + lower_body_ids=results['lower_body_ids']) + + bbox_center = [] + bbox_scale = [] + + for i, indices in enumerate(half_body_ids): + if indices is None: + bbox_center.append(results['bbox_center'][i]) + bbox_scale.append(results['bbox_scale'][i]) + else: + _center, _scale = self._get_half_body_bbox( + results['keypoints'][i], indices) + bbox_center.append(_center) + bbox_scale.append(_scale) + + results['bbox_center'] = np.stack(bbox_center) + results['bbox_scale'] = np.stack(bbox_scale) + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(min_total_keypoints={self.min_total_keypoints}, ' + repr_str += f'min_upper_keypoints={self.min_upper_keypoints}, ' + repr_str += f'min_lower_keypoints={self.min_lower_keypoints}, ' + repr_str += f'padding={self.padding}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'upper_prioritized_prob={self.upper_prioritized_prob})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomBBoxTransform(BaseTransform): + r"""Rnadomly shift, resize and rotate the bounding boxes. + + Required Keys: + + - bbox_center + - bbox_scale + + Modified Keys: + + - bbox_center + - bbox_scale + + Added Keys: + - bbox_rotation + + Args: + shift_factor (float): Randomly shift the bbox in range + :math:`[-dx, dx]` and :math:`[-dy, dy]` in X and Y directions, + where :math:`dx(y) = x(y)_scale \cdot shift_factor` in pixels. + Defaults to 0.16 + shift_prob (float): Probability of applying random shift. Defaults to + 0.3 + scale_factor (Tuple[float, float]): Randomly resize the bbox in range + :math:`[scale_factor[0], scale_factor[1]]`. Defaults to (0.5, 1.5) + scale_prob (float): Probability of applying random resizing. Defaults + to 1.0 + rotate_factor (float): Randomly rotate the bbox in + :math:`[-rotate_factor, rotate_factor]` in degrees. Defaults + to 80.0 + rotate_prob (float): Probability of applying random rotation. Defaults + to 0.6 + """ + + def __init__(self, + shift_factor: float = 0.16, + shift_prob: float = 0.3, + scale_factor: Tuple[float, float] = (0.5, 1.5), + scale_prob: float = 1.0, + rotate_factor: float = 80.0, + rotate_prob: float = 0.6) -> None: + super().__init__() + + self.shift_factor = shift_factor + self.shift_prob = shift_prob + self.scale_factor = scale_factor + self.scale_prob = scale_prob + self.rotate_factor = rotate_factor + self.rotate_prob = rotate_prob + + @staticmethod + def _truncnorm(low: float = -1., + high: float = 1., + size: tuple = ()) -> np.ndarray: + """Sample from a truncated normal distribution.""" + return truncnorm.rvs(low, high, size=size).astype(np.float32) + + @cache_randomness + def _get_transform_params(self, num_bboxes: int) -> Tuple: + """Get random transform parameters. + + Args: + num_bboxes (int): The number of bboxes + + Returns: + tuple: + - offset (np.ndarray): Offset factor of each bbox in shape (n, 2) + - scale (np.ndarray): Scaling factor of each bbox in shape (n, 1) + - rotate (np.ndarray): Rotation degree of each bbox in shape (n,) + """ + # Get shift parameters + offset = self._truncnorm(size=(num_bboxes, 2)) * self.shift_factor + offset = np.where( + np.random.rand(num_bboxes, 1) < self.shift_prob, offset, 0.) + + # Get scaling parameters + scale_min, scale_max = self.scale_factor + mu = (scale_max + scale_min) * 0.5 + sigma = (scale_max - scale_min) * 0.5 + scale = self._truncnorm(size=(num_bboxes, 1)) * sigma + mu + scale = np.where( + np.random.rand(num_bboxes, 1) < self.scale_prob, scale, 1.) + + # Get rotation parameters + rotate = self._truncnorm(size=(num_bboxes, )) * self.rotate_factor + rotate = np.where( + np.random.rand(num_bboxes) < self.rotate_prob, rotate, 0.) + + return offset, scale, rotate + + def transform(self, results: Dict) -> Optional[dict]: + """The transform function of :class:`RandomBboxTransform`. + + See ``transform()`` method of :class:`BaseTransform` for details. + + Args: + results (dict): The result dict + + Returns: + dict: The result dict. + """ + bbox_scale = results['bbox_scale'] + num_bboxes = bbox_scale.shape[0] + + offset, scale, rotate = self._get_transform_params(num_bboxes) + + results['bbox_center'] += offset * bbox_scale + results['bbox_scale'] *= scale + results['bbox_rotation'] = rotate + + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(shift_prob={self.shift_prob}, ' + repr_str += f'shift_factor={self.shift_factor}, ' + repr_str += f'scale_prob={self.scale_prob}, ' + repr_str += f'scale_factor={self.scale_factor}, ' + repr_str += f'rotate_prob={self.rotate_prob}, ' + repr_str += f'rotate_factor={self.rotate_factor})' + return repr_str + + +@TRANSFORMS.register_module() +@avoid_cache_randomness +class Albumentation(BaseTransform): + """Albumentation augmentation (pixel-level transforms only). + + Adds custom pixel-level transformations from Albumentations library. + Please visit `https://albumentations.ai/docs/` + to get more information. + + Note: we only support pixel-level transforms. + Please visit `https://github.com/albumentations-team/` + `albumentations#pixel-level-transforms` + to get more information about pixel-level transforms. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + transforms (List[dict]): A list of Albumentation transforms. + An example of ``transforms`` is as followed: + .. code-block:: python + + [ + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + keymap (dict | None): key mapping from ``input key`` to + ``albumentation-style key``. + Defaults to None, which will use {'img': 'image'}. + """ + + def __init__(self, + transforms: List[dict], + keymap: Optional[dict] = None) -> None: + if albumentations is None: + raise RuntimeError('albumentations is not installed') + + self.transforms = transforms + + self.aug = albumentations.Compose( + [self.albu_builder(t) for t in self.transforms]) + + if not keymap: + self.keymap_to_albu = { + 'img': 'image', + } + else: + self.keymap_to_albu = keymap + + def albu_builder(self, cfg: dict) -> albumentations: + """Import a module from albumentations. + + It resembles some of :func:`build_from_cfg` logic. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + + Returns: + albumentations.BasicTransform: The constructed transform object + """ + + assert isinstance(cfg, dict) and 'type' in cfg + args = cfg.copy() + + obj_type = args.pop('type') + if mmengine.is_str(obj_type): + if albumentations is None: + raise RuntimeError('albumentations is not installed') + rank, _ = get_dist_info() + if rank == 0 and not hasattr( + albumentations.augmentations.transforms, obj_type): + warnings.warn( + f'{obj_type} is not pixel-level transformations. ' + 'Please use with caution.') + obj_cls = getattr(albumentations, obj_type) + else: + raise TypeError(f'type must be a str, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + def transform(self, results: dict) -> dict: + """The transform function of :class:`Albumentation` to apply + albumentations transforms. + + See ``transform()`` method of :class:`BaseTransform` for details. + + Args: + results (dict): Result dict from the data pipeline. + + Return: + dict: updated result dict. + """ + # map result dict to albumentations format + results_albu = {} + for k, v in self.keymap_to_albu.items(): + assert k in results, \ + f'The `{k}` is required to perform albumentations transforms' + results_albu[v] = results[k] + + # Apply albumentations transforms + results_albu = self.aug(**results_albu) + + # map the albu results back to the original format + for k, v in self.keymap_to_albu.items(): + results[k] = results_albu[v] + + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + f'(transforms={self.transforms})' + return repr_str + + +@TRANSFORMS.register_module() +class PhotometricDistortion(BaseTransform): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + 8. randomly swap channels + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta: int = 32, + contrast_range: Sequence[Number] = (0.5, 1.5), + saturation_range: Sequence[Number] = (0.5, 1.5), + hue_delta: int = 18) -> None: + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + @cache_randomness + def _random_flags(self) -> Sequence[Number]: + """Generate the random flags for subsequent transforms. + + Returns: + Sequence[Number]: a sequence of numbers that indicate whether to + do the corresponding transforms. + """ + # contrast_mode == 0 --> do random contrast first + # contrast_mode == 1 --> do random contrast last + contrast_mode = np.random.randint(2) + # whether to apply brightness distortion + brightness_flag = np.random.randint(2) + # whether to apply contrast distortion + contrast_flag = np.random.randint(2) + # the mode to convert color from BGR to HSV + hsv_mode = np.random.randint(4) + # whether to apply channel swap + swap_flag = np.random.randint(2) + + # the beta in `self._convert` to be added to image array + # in brightness distortion + brightness_beta = np.random.uniform(-self.brightness_delta, + self.brightness_delta) + # the alpha in `self._convert` to be multiplied to image array + # in contrast distortion + contrast_alpha = np.random.uniform(self.contrast_lower, + self.contrast_upper) + # the alpha in `self._convert` to be multiplied to image array + # in saturation distortion to hsv-formatted img + saturation_alpha = np.random.uniform(self.saturation_lower, + self.saturation_upper) + # delta of hue to add to image array in hue distortion + hue_delta = np.random.randint(-self.hue_delta, self.hue_delta) + # the random permutation of channel order + swap_channel_order = np.random.permutation(3) + + return (contrast_mode, brightness_flag, contrast_flag, hsv_mode, + swap_flag, brightness_beta, contrast_alpha, saturation_alpha, + hue_delta, swap_channel_order) + + def _convert(self, + img: np.ndarray, + alpha: float = 1, + beta: float = 0) -> np.ndarray: + """Multiple with alpha and add beta with clip. + + Args: + img (np.ndarray): The image array. + alpha (float): The random multiplier. + beta (float): The random offset. + + Returns: + np.ndarray: The updated image array. + """ + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def transform(self, results: dict) -> dict: + """The transform function of :class:`PhotometricDistortion` to perform + photometric distortion on images. + + See ``transform()`` method of :class:`BaseTransform` for details. + + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + assert 'img' in results, '`img` is not found in results' + img = results['img'] + + (contrast_mode, brightness_flag, contrast_flag, hsv_mode, swap_flag, + brightness_beta, contrast_alpha, saturation_alpha, hue_delta, + swap_channel_order) = self._random_flags() + + # random brightness distortion + if brightness_flag: + img = self._convert(img, beta=brightness_beta) + + # contrast_mode == 0 --> do random contrast first + # contrast_mode == 1 --> do random contrast last + if contrast_mode == 1: + if contrast_flag: + img = self._convert(img, alpha=contrast_alpha) + + if hsv_mode: + # random saturation/hue distortion + img = mmcv.bgr2hsv(img) + if hsv_mode == 1 or hsv_mode == 3: + # apply saturation distortion to hsv-formatted img + img[:, :, 1] = self._convert( + img[:, :, 1], alpha=saturation_alpha) + if hsv_mode == 2 or hsv_mode == 3: + # apply hue distortion to hsv-formatted img + img[:, :, 0] = img[:, :, 0].astype(int) + hue_delta + img = mmcv.hsv2bgr(img) + + if contrast_mode == 1: + if contrast_flag: + img = self._convert(img, alpha=contrast_alpha) + + # randomly swap channels + if swap_flag: + img = img[..., swap_channel_order] + + results['img'] = img + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta})') + return repr_str + + +@TRANSFORMS.register_module() +class GenerateTarget(BaseTransform): + """Encode keypoints into Target. + + The generated target is usually the supervision signal of the model + learning, e.g. heatmaps or regression labels. + + Required Keys: + + - keypoints + - keypoints_visible + - dataset_keypoint_weights + + Added Keys: + + - The keys of the encoded items from the codec will be updated into + the results, e.g. ``'heatmaps'`` or ``'keypoint_weights'``. See + the specific codec for more details. + + Args: + encoder (dict | list[dict]): The codec config for keypoint encoding. + Both single encoder and multiple encoders (given as a list) are + supported + multilevel (bool): Determine the method to handle multiple encoders. + If ``multilevel==True``, generate multilevel targets from a group + of encoders of the same type (e.g. multiple :class:`MSRAHeatmap` + encoders with different sigma values); If ``multilevel==False``, + generate combined targets from a group of different encoders. This + argument will have no effect in case of single encoder. Defaults + to ``False`` + use_dataset_keypoint_weights (bool): Whether use the keypoint weights + from the dataset meta information. Defaults to ``False`` + target_type (str, deprecated): This argument is deprecated and has no + effect. Defaults to ``None`` + """ + + def __init__(self, + encoder: MultiConfig, + target_type: Optional[str] = None, + multilevel: bool = False, + use_dataset_keypoint_weights: bool = False) -> None: + super().__init__() + + if target_type is not None: + rank, _ = get_dist_info() + if rank == 0: + warnings.warn( + 'The argument `target_type` is deprecated in' + ' GenerateTarget. The target type and encoded ' + 'keys will be determined by encoder(s).', + DeprecationWarning) + + self.encoder_cfg = deepcopy(encoder) + self.multilevel = multilevel + self.use_dataset_keypoint_weights = use_dataset_keypoint_weights + + if isinstance(self.encoder_cfg, list): + self.encoder = [ + KEYPOINT_CODECS.build(cfg) for cfg in self.encoder_cfg + ] + else: + assert not self.multilevel, ( + 'Need multiple encoder configs if ``multilevel==True``') + self.encoder = KEYPOINT_CODECS.build(self.encoder_cfg) + + def transform(self, results: Dict) -> Optional[dict]: + """The transform function of :class:`GenerateTarget`. + + See ``transform()`` method of :class:`BaseTransform` for details. + """ + + if results.get('transformed_keypoints', None) is not None: + # use keypoints transformed by TopdownAffine + keypoints = results['transformed_keypoints'] + elif results.get('keypoints', None) is not None: + # use original keypoints + keypoints = results['keypoints'] + else: + raise ValueError( + 'GenerateTarget requires \'transformed_keypoints\' or' + ' \'keypoints\' in the results.') + + keypoints_visible = results['keypoints_visible'] + + # Encoded items from the encoder(s) will be updated into the results. + # Please refer to the document of the specific codec for details about + # encoded items. + if not isinstance(self.encoder, list): + # For single encoding, the encoded items will be directly added + # into results. + auxiliary_encode_kwargs = { + key: results[key] + for key in self.encoder.auxiliary_encode_keys + } + encoded = self.encoder.encode( + keypoints=keypoints, + keypoints_visible=keypoints_visible, + **auxiliary_encode_kwargs) + + else: + encoded_list = [] + for _encoder in self.encoder: + auxiliary_encode_kwargs = { + key: results[key] + for key in _encoder.auxiliary_encode_keys + } + encoded_list.append( + _encoder.encode( + keypoints=keypoints, + keypoints_visible=keypoints_visible, + **auxiliary_encode_kwargs)) + + if self.multilevel: + # For multilevel encoding, the encoded items from each encoder + # should have the same keys. + + keys = encoded_list[0].keys() + if not all(_encoded.keys() == keys + for _encoded in encoded_list): + raise ValueError( + 'Encoded items from all encoders must have the same ' + 'keys if ``multilevel==True``.') + + encoded = { + k: [_encoded[k] for _encoded in encoded_list] + for k in keys + } + + else: + # For combined encoding, the encoded items from different + # encoders should have no overlapping items, except for + # `keypoint_weights`. If multiple `keypoint_weights` are given, + # they will be multiplied as the final `keypoint_weights`. + + encoded = dict() + keypoint_weights = [] + + for _encoded in encoded_list: + for key, value in _encoded.items(): + if key == 'keypoint_weights': + keypoint_weights.append(value) + elif key not in encoded: + encoded[key] = value + else: + raise ValueError( + f'Overlapping item "{key}" from multiple ' + 'encoders, which is not supported when ' + '``multilevel==False``') + + if keypoint_weights: + encoded['keypoint_weights'] = keypoint_weights + + if self.use_dataset_keypoint_weights and 'keypoint_weights' in encoded: + if isinstance(encoded['keypoint_weights'], list): + for w in encoded['keypoint_weights']: + w *= results['dataset_keypoint_weights'] + else: + encoded['keypoint_weights'] *= results[ + 'dataset_keypoint_weights'] + + results.update(encoded) + + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += (f'(encoder={str(self.encoder_cfg)}, ') + repr_str += ('use_dataset_keypoint_weights=' + f'{self.use_dataset_keypoint_weights})') + return repr_str diff --git a/mmpose/datasets/transforms/converting.py b/mmpose/datasets/transforms/converting.py new file mode 100644 index 0000000000000000000000000000000000000000..38dcea09946eaafe396f4ba8e23cafa3d2da7ecc --- /dev/null +++ b/mmpose/datasets/transforms/converting.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +import numpy as np +from mmcv.transforms import BaseTransform + +from mmpose.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class KeypointConverter(BaseTransform): + """Change the order of keypoints according to the given mapping. + + Required Keys: + + - keypoints + - keypoints_visible + + Modified Keys: + + - keypoints + - keypoints_visible + + Args: + num_keypoints (int): The number of keypoints in target dataset. + mapping (list): A list containing mapping indexes. Each element has + format (source_index, target_index) + + Example: + >>> import numpy as np + >>> # case 1: 1-to-1 mapping + >>> # (0, 0) means target[0] = source[0] + >>> self = KeypointConverter( + >>> num_keypoints=3, + >>> mapping=[ + >>> (0, 0), (1, 1), (2, 2), (3, 3) + >>> ]) + >>> results = dict( + >>> keypoints=np.arange(34).reshape(2, 3, 2), + >>> keypoints_visible=np.arange(34).reshape(2, 3, 2) % 2) + >>> results = self(results) + >>> assert np.equal(results['keypoints'], + >>> np.arange(34).reshape(2, 3, 2)).all() + >>> assert np.equal(results['keypoints_visible'], + >>> np.arange(34).reshape(2, 3, 2) % 2).all() + >>> + >>> # case 2: 2-to-1 mapping + >>> # ((1, 2), 0) means target[0] = (source[1] + source[2]) / 2 + >>> self = KeypointConverter( + >>> num_keypoints=3, + >>> mapping=[ + >>> ((1, 2), 0), (1, 1), (2, 2) + >>> ]) + >>> results = dict( + >>> keypoints=np.arange(34).reshape(2, 3, 2), + >>> keypoints_visible=np.arange(34).reshape(2, 3, 2) % 2) + >>> results = self(results) + """ + + def __init__(self, num_keypoints: int, + mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple, + int]]]): + self.num_keypoints = num_keypoints + self.mapping = mapping + source_index, target_index = zip(*mapping) + + src1, src2 = [], [] + interpolation = False + for x in source_index: + if isinstance(x, (list, tuple)): + assert len(x) == 2, 'source_index should be a list/tuple of ' \ + 'length 2' + src1.append(x[0]) + src2.append(x[1]) + interpolation = True + else: + src1.append(x) + src2.append(x) + + # When paired source_indexes are input, + # keep a self.source_index2 for interpolation + if interpolation: + self.source_index2 = src2 + + self.source_index = src1 + self.target_index = target_index + self.interpolation = interpolation + + def transform(self, results: dict) -> dict: + num_instances = results['keypoints'].shape[0] + + keypoints = np.zeros((num_instances, self.num_keypoints, 2)) + keypoints_visible = np.zeros((num_instances, self.num_keypoints)) + + # When paired source_indexes are input, + # perform interpolation with self.source_index and self.source_index2 + if self.interpolation: + keypoints[:, self.target_index] = 0.5 * ( + results['keypoints'][:, self.source_index] + + results['keypoints'][:, self.source_index2]) + + keypoints_visible[:, self.target_index] = results[ + 'keypoints_visible'][:, self.source_index] * \ + results['keypoints_visible'][:, self.source_index2] + else: + keypoints[:, + self.target_index] = results['keypoints'][:, self. + source_index] + keypoints_visible[:, self.target_index] = results[ + 'keypoints_visible'][:, self.source_index] + + results['keypoints'] = keypoints + results['keypoints_visible'] = keypoints_visible + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(num_keypoints={self.num_keypoints}, '\ + f'mapping={self.mapping})' + return repr_str diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9ad522f295d86daec05d91d95e567af6ad9878 --- /dev/null +++ b/mmpose/datasets/transforms/formatting.py @@ -0,0 +1,218 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Union + +import numpy as np +import torch +from mmcv.transforms import BaseTransform +from mmengine.structures import InstanceData, PixelData +from mmengine.utils import is_seq_of + +from mmpose.registry import TRANSFORMS +from mmpose.structures import MultilevelPixelData, PoseDataSample + + +def image_to_tensor(img: Union[np.ndarray, + Sequence[np.ndarray]]) -> torch.torch.Tensor: + """Translate image or sequence of images to tensor. Multiple image tensors + will be stacked. + + Args: + value (np.ndarray | Sequence[np.ndarray]): The original image or + image sequence + + Returns: + torch.Tensor: The output tensor. + """ + + if isinstance(img, np.ndarray): + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + + img = np.ascontiguousarray(img) + tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous() + else: + assert is_seq_of(img, np.ndarray) + tensor = torch.stack([image_to_tensor(_img) for _img in img]) + + return tensor + + +@TRANSFORMS.register_module() +class PackPoseInputs(BaseTransform): + """Pack the inputs data for pose estimation. + + The ``img_meta`` item is always populated. The contents of the + ``img_meta`` dictionary depends on ``meta_keys``. By default it includes: + + - ``id``: id of the data sample + + - ``img_id``: id of the image + + - ``'category_id'``: the id of the instance category + + - ``img_path``: path to the image file + + - ``crowd_index`` (optional): measure the crowding level of an image, + defined in CrowdPose dataset + + - ``ori_shape``: original shape of the image as a tuple (h, w, c) + + - ``img_shape``: shape of the image input to the network as a tuple \ + (h, w). Note that images may be zero padded on the \ + bottom/right if the batch tensor is larger than this shape. + + - ``input_size``: the input size to the network + + - ``flip``: a boolean indicating if image flip transform was used + + - ``flip_direction``: the flipping direction + + - ``flip_indices``: the indices of each keypoint's symmetric keypoint + + - ``raw_ann_info`` (optional): raw annotation of the instance(s) + + Args: + meta_keys (Sequence[str], optional): Meta keys which will be stored in + :obj: `PoseDataSample` as meta info. Defaults to ``('id', + 'img_id', 'img_path', 'category_id', 'crowd_index, 'ori_shape', + 'img_shape',, 'input_size', 'input_center', 'input_scale', 'flip', + 'flip_direction', 'flip_indices', 'raw_ann_info')`` + """ + + # items in `instance_mapping_table` will be directly packed into + # PoseDataSample.gt_instances without converting to Tensor + instance_mapping_table = { + 'bbox': 'bboxes', + 'head_size': 'head_size', + 'bbox_center': 'bbox_centers', + 'bbox_scale': 'bbox_scales', + 'bbox_score': 'bbox_scores', + 'keypoints': 'keypoints', + 'keypoints_visible': 'keypoints_visible', + } + + # items in `label_mapping_table` will be packed into + # PoseDataSample.gt_instance_labels and converted to Tensor. These items + # will be used for computing losses + label_mapping_table = { + 'keypoint_labels': 'keypoint_labels', + 'keypoint_x_labels': 'keypoint_x_labels', + 'keypoint_y_labels': 'keypoint_y_labels', + 'keypoint_weights': 'keypoint_weights', + 'instance_coords': 'instance_coords' + } + + # items in `field_mapping_table` will be packed into + # PoseDataSample.gt_fields and converted to Tensor. These items will be + # used for computing losses + field_mapping_table = { + 'heatmaps': 'heatmaps', + 'instance_heatmaps': 'instance_heatmaps', + 'heatmap_mask': 'heatmap_mask', + 'heatmap_weights': 'heatmap_weights', + 'displacements': 'displacements', + 'displacement_weights': 'displacement_weights', + } + + def __init__(self, + meta_keys=('id', 'img_id', 'img_path', 'category_id', + 'crowd_index', 'ori_shape', 'img_shape', + 'input_size', 'input_center', 'input_scale', + 'flip', 'flip_direction', 'flip_indices', + 'raw_ann_info'), + pack_transformed=False): + self.meta_keys = meta_keys + self.pack_transformed = pack_transformed + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: + + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_samples' (obj:`PoseDataSample`): The annotation info of the + sample. + """ + # Pack image(s) + if 'img' in results: + img = results['img'] + img_tensor = image_to_tensor(img) + + data_sample = PoseDataSample() + + # pack instance data + gt_instances = InstanceData() + for key, packed_key in self.instance_mapping_table.items(): + if key in results: + gt_instances.set_field(results[key], packed_key) + + # pack `transformed_keypoints` for visualizing data transform + # and augmentation results + if self.pack_transformed and 'transformed_keypoints' in results: + gt_instances.set_field(results['transformed_keypoints'], + 'transformed_keypoints') + + data_sample.gt_instances = gt_instances + + # pack instance labels + gt_instance_labels = InstanceData() + for key, packed_key in self.label_mapping_table.items(): + if key in results: + if isinstance(results[key], list): + # A list of labels is usually generated by combined + # multiple encoders (See ``GenerateTarget`` in + # mmpose/datasets/transforms/common_transforms.py) + # In this case, labels in list should have the same + # shape and will be stacked. + _labels = np.stack(results[key]) + gt_instance_labels.set_field(_labels, packed_key) + else: + gt_instance_labels.set_field(results[key], packed_key) + data_sample.gt_instance_labels = gt_instance_labels.to_tensor() + + # pack fields + gt_fields = None + for key, packed_key in self.field_mapping_table.items(): + if key in results: + if isinstance(results[key], list): + if gt_fields is None: + gt_fields = MultilevelPixelData() + else: + assert isinstance( + gt_fields, MultilevelPixelData + ), 'Got mixed single-level and multi-level pixel data.' + else: + if gt_fields is None: + gt_fields = PixelData() + else: + assert isinstance( + gt_fields, PixelData + ), 'Got mixed single-level and multi-level pixel data.' + + gt_fields.set_field(results[key], packed_key) + + if gt_fields: + data_sample.gt_fields = gt_fields.to_tensor() + + img_meta = {k: results[k] for k in self.meta_keys if k in results} + data_sample.set_metainfo(img_meta) + + packed_results = dict() + packed_results['inputs'] = img_tensor + packed_results['data_samples'] = data_sample + + return packed_results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str diff --git a/mmpose/datasets/transforms/loading.py b/mmpose/datasets/transforms/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..28edcb48066b7a7b5534f69e6b3981d31825beca --- /dev/null +++ b/mmpose/datasets/transforms/loading.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import numpy as np +from mmcv.transforms import LoadImageFromFile + +from mmpose.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class LoadImage(LoadImageFromFile): + """Load an image from file or from the np.ndarray in ``results['img']``. + + Required Keys: + + - img_path + - img (optional) + + Modified Keys: + + - img + - img_shape + - ori_shape + - img_path (optional) + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:``mmcv.imfrombytes``. + Defaults to 'color'. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :func:``mmcv.imfrombytes`` for details. + Defaults to 'cv2'. + backend_args (dict, optional): Arguments to instantiate the preifx of + uri corresponding backend. Defaults to None. + ignore_empty (bool): Whether to allow loading empty image or file path + not existent. Defaults to False. + """ + + def transform(self, results: dict) -> Optional[dict]: + """The transform function of :class:`LoadImage`. + + Args: + results (dict): The result dict + + Returns: + dict: The result dict. + """ + + if 'img' not in results: + # Load image from file by :meth:`LoadImageFromFile.transform` + results = super().transform(results) + else: + img = results['img'] + assert isinstance(img, np.ndarray) + if self.to_float32: + img = img.astype(np.float32) + + if 'img_path' not in results: + results['img_path'] = None + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + + return results diff --git a/mmpose/datasets/transforms/topdown_transforms.py b/mmpose/datasets/transforms/topdown_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..29aa48eb06576059c8fad4e1d46de4e5061f372f --- /dev/null +++ b/mmpose/datasets/transforms/topdown_transforms.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple + +import cv2 +import numpy as np +from mmcv.transforms import BaseTransform +from mmengine import is_seq_of + +from mmpose.registry import TRANSFORMS +from mmpose.structures.bbox import get_udp_warp_matrix, get_warp_matrix + + +@TRANSFORMS.register_module() +class TopdownAffine(BaseTransform): + """Get the bbox image as the model input by affine transform. + + Required Keys: + + - img + - bbox_center + - bbox_scale + - bbox_rotation (optional) + - keypoints (optional) + + Modified Keys: + + - img + - bbox_scale + + Added Keys: + + - input_size + - transformed_keypoints + + Args: + input_size (Tuple[int, int]): The input image size of the model in + [w, h]. The bbox region will be cropped and resize to `input_size` + use_udp (bool): Whether use unbiased data processing. See + `UDP (CVPR 2020)`_ for details. Defaults to ``False`` + + .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524 + """ + + def __init__(self, + input_size: Tuple[int, int], + use_udp: bool = False) -> None: + super().__init__() + + assert is_seq_of(input_size, int) and len(input_size) == 2, ( + f'Invalid input_size {input_size}') + + self.input_size = input_size + self.use_udp = use_udp + + @staticmethod + def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float): + """Reshape the bbox to a fixed aspect ratio. + + Args: + bbox_scale (np.ndarray): The bbox scales (w, h) in shape (n, 2) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.darray: The reshaped bbox scales in (n, 2) + """ + + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + def transform(self, results: Dict) -> Optional[dict]: + """The transform function of :class:`TopdownAffine`. + + See ``transform()`` method of :class:`BaseTransform` for details. + + Args: + results (dict): The result dict + + Returns: + dict: The result dict. + """ + + w, h = self.input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + results['bbox_scale'] = self._fix_aspect_ratio( + results['bbox_scale'], aspect_ratio=w / h) + + # TODO: support multi-instance + assert results['bbox_center'].shape[0] == 1, ( + 'Top-down heatmap only supports single instance. Got invalid ' + f'shape of bbox_center {results["bbox_center"].shape}.') + + center = results['bbox_center'][0] + scale = results['bbox_scale'][0] + if 'bbox_rotation' in results: + rot = results['bbox_rotation'][0] + else: + rot = 0. + + if self.use_udp: + warp_mat = get_udp_warp_matrix( + center, scale, rot, output_size=(w, h)) + else: + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + if isinstance(results['img'], list): + results['img'] = [ + cv2.warpAffine( + img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + for img in results['img'] + ] + else: + results['img'] = cv2.warpAffine( + results['img'], warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + if results.get('keypoints', None) is not None: + transformed_keypoints = results['keypoints'].copy() + # Only transform (x, y) coordinates + transformed_keypoints[..., :2] = cv2.transform( + results['keypoints'][..., :2], warp_mat) + results['transformed_keypoints'] = transformed_keypoints + + results['input_size'] = (w, h) + + return results + + def __repr__(self) -> str: + """print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(input_size={self.input_size}, ' + repr_str += f'use_udp={self.use_udp})' + return repr_str diff --git a/mmpose/engine/__init__.py b/mmpose/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac85928986718cfc181ac311949c238ea11cf34c --- /dev/null +++ b/mmpose/engine/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hooks import * # noqa: F401, F403 +from .optim_wrappers import * # noqa: F401, F403 diff --git a/mmpose/engine/__pycache__/__init__.cpython-38.pyc b/mmpose/engine/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d8cc2201d5a5be2e91e002781177270fbfc6793 Binary files /dev/null and b/mmpose/engine/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/engine/hooks/__init__.py b/mmpose/engine/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dadb9c5f913e3e61eebb4ee4e246076bb3d45dd9 --- /dev/null +++ b/mmpose/engine/hooks/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ema_hook import ExpMomentumEMA +from .visualization_hook import PoseVisualizationHook + +__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA'] diff --git a/mmpose/engine/hooks/__pycache__/__init__.cpython-38.pyc b/mmpose/engine/hooks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ab45eaf91278f1165c44221818f8694926a65a8 Binary files /dev/null and b/mmpose/engine/hooks/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/engine/hooks/__pycache__/ema_hook.cpython-38.pyc b/mmpose/engine/hooks/__pycache__/ema_hook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f638feee9b91007eaa02ccdce18e6d5bca270020 Binary files /dev/null and b/mmpose/engine/hooks/__pycache__/ema_hook.cpython-38.pyc differ diff --git a/mmpose/engine/hooks/__pycache__/visualization_hook.cpython-38.pyc b/mmpose/engine/hooks/__pycache__/visualization_hook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9590b9dceaf9f8ae48bdb96438e73c465772526d Binary files /dev/null and b/mmpose/engine/hooks/__pycache__/visualization_hook.cpython-38.pyc differ diff --git a/mmpose/engine/hooks/ema_hook.py b/mmpose/engine/hooks/ema_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..fd1a689f96f49c33059ec1e4afbe7b01b85164f9 --- /dev/null +++ b/mmpose/engine/hooks/ema_hook.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.model import ExponentialMovingAverage +from torch import Tensor + +from mmpose.registry import MODELS + + +@MODELS.register_module() +class ExpMomentumEMA(ExponentialMovingAverage): + """Exponential moving average (EMA) with exponential momentum strategy, + which is used in YOLOX. + + Ported from ` the implementation of MMDetection + `_. + + Args: + model (nn.Module): The model to be averaged. + momentum (float): The momentum used for updating ema parameter. + Ema's parameter are updated with the formula: + `averaged_param = (1-momentum) * averaged_param + momentum * + source_param`. Defaults to 0.0002. + gamma (int): Use a larger momentum early in training and gradually + annealing to a smaller value to update the ema model smoothly. The + momentum is calculated as + `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`. + Defaults to 2000. + interval (int): Interval between two updates. Defaults to 1. + device (torch.device, optional): If provided, the averaged model will + be stored on the :attr:`device`. Defaults to None. + update_buffers (bool): if True, it will compute running averages for + both the parameters and the buffers of the model. Defaults to + False. + """ + + def __init__(self, + model: nn.Module, + momentum: float = 0.0002, + gamma: int = 2000, + interval=1, + device: Optional[torch.device] = None, + update_buffers: bool = False) -> None: + super().__init__( + model=model, + momentum=momentum, + interval=interval, + device=device, + update_buffers=update_buffers) + assert gamma > 0, f'gamma must be greater than 0, but got {gamma}' + self.gamma = gamma + + def avg_func(self, averaged_param: Tensor, source_param: Tensor, + steps: int) -> None: + """Compute the moving average of the parameters using the exponential + momentum strategy. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + """ + momentum = (1 - self.momentum) * math.exp( + -float(1 + steps) / self.gamma) + self.momentum + averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum) diff --git a/mmpose/engine/hooks/visualization_hook.py b/mmpose/engine/hooks/visualization_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..24b845f282291bfd60d2a1a507ce42d586a35073 --- /dev/null +++ b/mmpose/engine/hooks/visualization_hook.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import warnings +from typing import Optional, Sequence + +import mmcv +import mmengine +import mmengine.fileio as fileio +from mmengine.hooks import Hook +from mmengine.runner import Runner +from mmengine.visualization import Visualizer + +from mmpose.registry import HOOKS +from mmpose.structures import PoseDataSample, merge_data_samples + + +@HOOKS.register_module() +class PoseVisualizationHook(Hook): + """Pose Estimation Visualization Hook. Used to visualize validation and + testing process prediction results. + + In the testing phase: + + 1. If ``show`` is True, it means that only the prediction results are + visualized without storing data, so ``vis_backends`` needs to + be excluded. + 2. If ``out_dir`` is specified, it means that the prediction results + need to be saved to ``out_dir``. In order to avoid vis_backends + also storing data, so ``vis_backends`` needs to be excluded. + 3. ``vis_backends`` takes effect if the user does not specify ``show`` + and `out_dir``. You can set ``vis_backends`` to WandbVisBackend or + TensorboardVisBackend to store the prediction result in Wandb or + Tensorboard. + + Args: + enable (bool): whether to draw prediction results. If it is False, + it means that no drawing will be done. Defaults to False. + interval (int): The interval of visualization. Defaults to 50. + score_thr (float): The threshold to visualize the bboxes + and masks. Defaults to 0.3. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_dir (str, optional): directory where painted images + will be saved in testing process. + backend_args (dict, optional): Arguments to instantiate the preifx of + uri corresponding backend. Defaults to None. + """ + + def __init__( + self, + enable: bool = False, + interval: int = 50, + kpt_thr: float = 0.3, + show: bool = False, + wait_time: float = 0., + out_dir: Optional[str] = None, + backend_args: Optional[dict] = None, + ): + self._visualizer: Visualizer = Visualizer.get_current_instance() + self.interval = interval + self.kpt_thr = kpt_thr + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.enable = enable + self.out_dir = out_dir + self._test_index = 0 + self.backend_args = backend_args + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[PoseDataSample]) -> None: + """Run after every ``self.interval`` validation iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model. + """ + if self.enable is False: + return + + self._visualizer.set_dataset_meta(runner.val_evaluator.dataset_meta) + + # There is no guarantee that the same batch of images + # is visualized for each evaluation. + total_curr_iter = runner.iter + batch_idx + + # Visualize only the first data + img_path = data_batch['data_samples'][0].get('img_path') + img_bytes = fileio.get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + data_sample = outputs[0] + + # revert the heatmap on the original image + data_sample = merge_data_samples([data_sample]) + + if total_curr_iter % self.interval == 0: + self._visualizer.add_datasample( + os.path.basename(img_path) if self.show else 'val_img', + img, + data_sample=data_sample, + draw_gt=False, + draw_bbox=True, + draw_heatmap=True, + show=self.show, + wait_time=self.wait_time, + kpt_thr=self.kpt_thr, + step=total_curr_iter) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[PoseDataSample]) -> None: + """Run after every testing iterations. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model. + """ + if self.enable is False: + return + + if self.out_dir is not None: + self.out_dir = os.path.join(runner.work_dir, runner.timestamp, + self.out_dir) + mmengine.mkdir_or_exist(self.out_dir) + + self._visualizer.set_dataset_meta(runner.test_evaluator.dataset_meta) + + for data_sample in outputs: + self._test_index += 1 + + img_path = data_sample.get('img_path') + img_bytes = fileio.get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + data_sample = merge_data_samples([data_sample]) + + out_file = None + if self.out_dir is not None: + out_file_name, postfix = os.path.basename(img_path).rsplit( + '.', 1) + index = len([ + fname for fname in os.listdir(self.out_dir) + if fname.startswith(out_file_name) + ]) + out_file = f'{out_file_name}_{index}.{postfix}' + out_file = os.path.join(self.out_dir, out_file) + + self._visualizer.add_datasample( + os.path.basename(img_path) if self.show else 'test_img', + img, + data_sample=data_sample, + show=self.show, + draw_gt=False, + draw_bbox=True, + draw_heatmap=True, + wait_time=self.wait_time, + kpt_thr=self.kpt_thr, + out_file=out_file, + step=self._test_index) diff --git a/mmpose/engine/optim_wrappers/__init__.py b/mmpose/engine/optim_wrappers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0b1f533ae88e0d9914bda0bd3e79532c779339 --- /dev/null +++ b/mmpose/engine/optim_wrappers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .layer_decay_optim_wrapper import LayerDecayOptimWrapperConstructor + +__all__ = ['LayerDecayOptimWrapperConstructor'] diff --git a/mmpose/engine/optim_wrappers/__pycache__/__init__.cpython-38.pyc b/mmpose/engine/optim_wrappers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6409ffa2e0dadc364d2902061eef3ec985f6dcb7 Binary files /dev/null and b/mmpose/engine/optim_wrappers/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/engine/optim_wrappers/__pycache__/layer_decay_optim_wrapper.cpython-38.pyc b/mmpose/engine/optim_wrappers/__pycache__/layer_decay_optim_wrapper.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb88ac3eb2a65d1220b46853498dcd540a049e21 Binary files /dev/null and b/mmpose/engine/optim_wrappers/__pycache__/layer_decay_optim_wrapper.cpython-38.pyc differ diff --git a/mmpose/engine/optim_wrappers/layer_decay_optim_wrapper.py b/mmpose/engine/optim_wrappers/layer_decay_optim_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6513e5593d98e9aa77a2795529ddeb538b6099c3 --- /dev/null +++ b/mmpose/engine/optim_wrappers/layer_decay_optim_wrapper.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dist.utils import get_dist_info +from mmengine.optim import DefaultOptimWrapperConstructor +from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +def get_num_layer_for_vit(var_name, num_max_layer): + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.patch_embed'): + return 0 + elif var_name.startswith('backbone.layers'): + layer_id = int(var_name.split('.')[2]) + return layer_id + 1 + else: + return num_max_layer - 1 + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module(force=True) +class LayerDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): + + def __init__(self, optim_wrapper_cfg, paramwise_cfg=None): + super().__init__(optim_wrapper_cfg, paramwise_cfg=None) + self.layer_decay_rate = paramwise_cfg.get('layer_decay_rate', 0.5) + + super().__init__(optim_wrapper_cfg, paramwise_cfg) + + def add_params(self, params, module, prefix='', lr=None): + parameter_groups = {} + print(self.paramwise_cfg) + num_layers = self.paramwise_cfg.get('num_layers') + 2 + layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate') + weight_decay = self.base_wd + + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if (len(param.shape) == 1 or name.endswith('.bias') + or 'pos_embed' in name): + group_name = 'no_decay' + this_weight_decay = 0. + else: + group_name = 'decay' + this_weight_decay = weight_decay + layer_id = get_num_layer_for_vit(name, num_layers) + group_name = 'layer_%d_%s' % (layer_id, group_name) + + if group_name not in parameter_groups: + scale = layer_decay_rate**(num_layers - layer_id - 1) + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.base_lr, + } + + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + rank, _ = get_dist_info() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + params.extend(parameter_groups.values()) diff --git a/mmpose/evaluation/__init__.py b/mmpose/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f70dc226d30f7b8e4ee5a44ca163ad1ae04eabf5 --- /dev/null +++ b/mmpose/evaluation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .functional import * # noqa: F401,F403 +from .metrics import * # noqa: F401,F403 diff --git a/mmpose/evaluation/__pycache__/__init__.cpython-38.pyc b/mmpose/evaluation/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3626954e238738013228b15cddb5f4003cc8542 Binary files /dev/null and b/mmpose/evaluation/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/evaluation/functional/__init__.py b/mmpose/evaluation/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c4a8b5d1ebde25eb71d7c82e138b1309af617f3 --- /dev/null +++ b/mmpose/evaluation/functional/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .keypoint_eval import (keypoint_auc, keypoint_epe, keypoint_nme, + keypoint_pck_accuracy, + multilabel_classification_accuracy, + pose_pck_accuracy, simcc_pck_accuracy) +from .nms import nms, oks_nms, soft_oks_nms + +__all__ = [ + 'keypoint_pck_accuracy', 'keypoint_auc', 'keypoint_nme', 'keypoint_epe', + 'pose_pck_accuracy', 'multilabel_classification_accuracy', + 'simcc_pck_accuracy', 'nms', 'oks_nms', 'soft_oks_nms' +] diff --git a/mmpose/evaluation/functional/__pycache__/__init__.cpython-38.pyc b/mmpose/evaluation/functional/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2213fd51049355be16f1f57c86488ac91255913 Binary files /dev/null and b/mmpose/evaluation/functional/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/evaluation/functional/__pycache__/keypoint_eval.cpython-38.pyc b/mmpose/evaluation/functional/__pycache__/keypoint_eval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cbb7d04a06c5d7b2c4fee75cdcaff54a5fd0f11 Binary files /dev/null and b/mmpose/evaluation/functional/__pycache__/keypoint_eval.cpython-38.pyc differ diff --git a/mmpose/evaluation/functional/__pycache__/nms.cpython-38.pyc b/mmpose/evaluation/functional/__pycache__/nms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e227dd3ce213c549ad10f315c4ebc24aa780419 Binary files /dev/null and b/mmpose/evaluation/functional/__pycache__/nms.cpython-38.pyc differ diff --git a/mmpose/evaluation/functional/keypoint_eval.py b/mmpose/evaluation/functional/keypoint_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..060243357b39f9bc7f4689b6ba59320bd86b9d4b --- /dev/null +++ b/mmpose/evaluation/functional/keypoint_eval.py @@ -0,0 +1,320 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import numpy as np + +from mmpose.codecs.utils import get_heatmap_maximum, get_simcc_maximum + + +def _calc_distances(preds: np.ndarray, gts: np.ndarray, mask: np.ndarray, + norm_factor: np.ndarray) -> np.ndarray: + """Calculate the normalized distances between preds and target. + + Note: + - instance number: N + - keypoint number: K + - keypoint dimension: D (normally, D=2 or D=3) + + Args: + preds (np.ndarray[N, K, D]): Predicted keypoint location. + gts (np.ndarray[N, K, D]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + norm_factor (np.ndarray[N, D]): Normalization factor. + Typical value is heatmap_size. + + Returns: + np.ndarray[K, N]: The normalized distances. \ + If target keypoints are missing, the distance is -1. + """ + N, K, _ = preds.shape + # set mask=0 when norm_factor==0 + _mask = mask.copy() + _mask[np.where((norm_factor == 0).sum(1))[0], :] = False + + distances = np.full((N, K), -1, dtype=np.float32) + # handle invalid values + norm_factor[np.where(norm_factor <= 0)] = 1e6 + distances[_mask] = np.linalg.norm( + ((preds - gts) / norm_factor[:, None, :])[_mask], axis=-1) + return distances.T + + +def _distance_acc(distances: np.ndarray, thr: float = 0.5) -> float: + """Return the percentage below the distance threshold, while ignoring + distances values with -1. + + Note: + - instance number: N + + Args: + distances (np.ndarray[N, ]): The normalized distances. + thr (float): Threshold of the distances. + + Returns: + float: Percentage of distances below the threshold. \ + If all target keypoints are missing, return -1. + """ + distance_valid = distances != -1 + num_distance_valid = distance_valid.sum() + if num_distance_valid > 0: + return (distances[distance_valid] < thr).sum() / num_distance_valid + return -1 + + +def keypoint_pck_accuracy(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray, + thr: np.ndarray, norm_factor: np.ndarray) -> tuple: + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - instance number: N + - keypoint number: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. + norm_factor (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - acc (np.ndarray[K]): Accuracy of each keypoint. + - avg_acc (float): Averaged accuracy across all keypoints. + - cnt (int): Number of valid keypoints. + """ + distances = _calc_distances(pred, gt, mask, norm_factor) + acc = np.array([_distance_acc(d, thr) for d in distances]) + valid_acc = acc[acc >= 0] + cnt = len(valid_acc) + avg_acc = valid_acc.mean() if cnt > 0 else 0 + return acc, avg_acc, cnt + + +def keypoint_auc(pred: np.ndarray, + gt: np.ndarray, + mask: np.ndarray, + norm_factor: np.ndarray, + num_thrs: int = 20) -> float: + """Calculate the Area under curve (AUC) of keypoint PCK accuracy. + + Note: + - instance number: N + - keypoint number: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + norm_factor (float): Normalization factor. + num_thrs (int): number of thresholds to calculate auc. + + Returns: + float: Area under curve (AUC) of keypoint PCK accuracy. + """ + nor = np.tile(np.array([[norm_factor, norm_factor]]), (pred.shape[0], 1)) + thrs = [1.0 * i / num_thrs for i in range(num_thrs)] + avg_accs = [] + for thr in thrs: + _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor) + avg_accs.append(avg_acc) + + auc = 0 + for i in range(num_thrs): + auc += 1.0 / num_thrs * avg_accs[i] + return auc + + +def keypoint_nme(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray, + normalize_factor: np.ndarray) -> float: + """Calculate the normalized mean error (NME). + + Note: + - instance number: N + - keypoint number: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize_factor (np.ndarray[N, 2]): Normalization factor. + + Returns: + float: normalized mean error + """ + distances = _calc_distances(pred, gt, mask, normalize_factor) + distance_valid = distances[distances != -1] + return distance_valid.sum() / max(1, len(distance_valid)) + + +def keypoint_epe(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray) -> float: + """Calculate the end-point error. + + Note: + - instance number: N + - keypoint number: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + + Returns: + float: Average end-point error. + """ + + distances = _calc_distances( + pred, gt, mask, + np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32)) + distance_valid = distances[distances != -1] + return distance_valid.sum() / max(1, len(distance_valid)) + + +def pose_pck_accuracy(output: np.ndarray, + target: np.ndarray, + mask: np.ndarray, + thr: float = 0.05, + normalize: Optional[np.ndarray] = None) -> tuple: + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints from heatmaps. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + output (np.ndarray[N, K, H, W]): Model output heatmaps. + target (np.ndarray[N, K, H, W]): Groundtruth heatmaps. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. Default 0.05. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - np.ndarray[K]: Accuracy of each keypoint. + - float: Averaged accuracy across all keypoints. + - int: Number of valid keypoints. + """ + N, K, H, W = output.shape + if K == 0: + return None, 0, 0 + if normalize is None: + normalize = np.tile(np.array([[H, W]]), (N, 1)) + + pred, _ = get_heatmap_maximum(output) + gt, _ = get_heatmap_maximum(target) + return keypoint_pck_accuracy(pred, gt, mask, thr, normalize) + + +def simcc_pck_accuracy(output: Tuple[np.ndarray, np.ndarray], + target: Tuple[np.ndarray, np.ndarray], + simcc_split_ratio: float, + mask: np.ndarray, + thr: float = 0.05, + normalize: Optional[np.ndarray] = None) -> tuple: + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints from SimCC. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - instance number: N + - keypoint number: K + + Args: + output (Tuple[np.ndarray, np.ndarray]): Model predicted SimCC. + target (Tuple[np.ndarray, np.ndarray]): Groundtruth SimCC. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. Default 0.05. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - np.ndarray[K]: Accuracy of each keypoint. + - float: Averaged accuracy across all keypoints. + - int: Number of valid keypoints. + """ + pred_x, pred_y = output + gt_x, gt_y = target + + N, _, Wx = pred_x.shape + _, _, Wy = pred_y.shape + W, H = int(Wx / simcc_split_ratio), int(Wy / simcc_split_ratio) + + if normalize is None: + normalize = np.tile(np.array([[H, W]]), (N, 1)) + + pred_coords, _ = get_simcc_maximum(pred_x, pred_y) + pred_coords /= simcc_split_ratio + gt_coords, _ = get_simcc_maximum(gt_x, gt_y) + gt_coords /= simcc_split_ratio + + return keypoint_pck_accuracy(pred_coords, gt_coords, mask, thr, normalize) + + +def multilabel_classification_accuracy(pred: np.ndarray, + gt: np.ndarray, + mask: np.ndarray, + thr: float = 0.5) -> float: + """Get multi-label classification accuracy. + + Note: + - batch size: N + - label number: L + + Args: + pred (np.ndarray[N, L, 2]): model predicted labels. + gt (np.ndarray[N, L, 2]): ground-truth labels. + mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of + ground-truth labels. + thr (float): Threshold for calculating accuracy. + + Returns: + float: multi-label classification accuracy. + """ + # we only compute accuracy on the samples with ground-truth of all labels. + valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0) + pred, gt = pred[valid], gt[valid] + + if pred.shape[0] == 0: + acc = 0.0 # when no sample is with gt labels, set acc to 0. + else: + # The classification of a sample is regarded as correct + # only if it's correct for all labels. + acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean() + return acc diff --git a/mmpose/evaluation/functional/nms.py b/mmpose/evaluation/functional/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..eed4e5cf736585667657a1d4a4ca34f1bc8c0423 --- /dev/null +++ b/mmpose/evaluation/functional/nms.py @@ -0,0 +1,327 @@ +# ------------------------------------------------------------------------------ +# Adapted from https://github.com/leoxiaobin/deep-high-resolution-net.pytorch +# and https://github.com/HRNet/DEKR +# Original licence: Copyright (c) Microsoft, under the MIT License. +# ------------------------------------------------------------------------------ + +from typing import List, Optional + +import numpy as np + + +def nms(dets: np.ndarray, thr: float) -> List[int]: + """Greedily select boxes with high confidence and overlap <= thr. + + Args: + dets (np.ndarray): [[x1, y1, x2, y2, score]]. + thr (float): Retain overlap < thr. + + Returns: + list: Indexes to keep. + """ + if len(dets) == 0: + return [] + + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while len(order) > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thr)[0] + order = order[inds + 1] + + return keep + + +def oks_iou(g: np.ndarray, + d: np.ndarray, + a_g: float, + a_d: np.ndarray, + sigmas: Optional[np.ndarray] = None, + vis_thr: Optional[float] = None) -> np.ndarray: + """Calculate oks ious. + + Note: + + - number of keypoints: K + - number of instances: N + + Args: + g (np.ndarray): The instance to calculate OKS IOU with other + instances. Containing the keypoints coordinates. Shape: (K*3, ) + d (np.ndarray): The rest instances. Containing the keypoints + coordinates. Shape: (N, K*3) + a_g (float): Area of the ground truth object. + a_d (np.ndarray): Area of the detected object. Shape: (N, ) + sigmas (np.ndarray, optional): Keypoint labelling uncertainty. + Please refer to `COCO keypoint evaluation + `__ for more details. + If not given, use the sigmas on COCO dataset. + If specified, shape: (K, ). Defaults to ``None`` + vis_thr(float, optional): Threshold of the keypoint visibility. + If specified, will calculate OKS based on those keypoints whose + visibility higher than vis_thr. If not given, calculate the OKS + based on all keypoints. Defaults to ``None`` + + Returns: + np.ndarray: The oks ious. + """ + if sigmas is None: + sigmas = np.array([ + .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, + .87, .87, .89, .89 + ]) / 10.0 + vars = (sigmas * 2)**2 + xg = g[0::3] + yg = g[1::3] + vg = g[2::3] + ious = np.zeros(len(d), dtype=np.float32) + for n_d in range(0, len(d)): + xd = d[n_d, 0::3] + yd = d[n_d, 1::3] + vd = d[n_d, 2::3] + dx = xd - xg + dy = yd - yg + e = (dx**2 + dy**2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2 + if vis_thr is not None: + ind = list((vg > vis_thr) & (vd > vis_thr)) + e = e[ind] + ious[n_d] = np.sum(np.exp(-e)) / len(e) if len(e) != 0 else 0.0 + return ious + + +def oks_nms(kpts_db: List[dict], + thr: float, + sigmas: Optional[np.ndarray] = None, + vis_thr: Optional[float] = None, + score_per_joint: bool = False): + """OKS NMS implementations. + + Args: + kpts_db (List[dict]): The keypoints results of the same image. + thr (float): The threshold of NMS. Will retain oks overlap < thr. + sigmas (np.ndarray, optional): Keypoint labelling uncertainty. + Please refer to `COCO keypoint evaluation + `__ for more details. + If not given, use the sigmas on COCO dataset. Defaults to ``None`` + vis_thr(float, optional): Threshold of the keypoint visibility. + If specified, will calculate OKS based on those keypoints whose + visibility higher than vis_thr. If not given, calculate the OKS + based on all keypoints. Defaults to ``None`` + score_per_joint(bool): Whether the input scores (in kpts_db) are + per-joint scores. Defaults to ``False`` + + Returns: + np.ndarray: indexes to keep. + """ + if len(kpts_db) == 0: + return [] + + if score_per_joint: + scores = np.array([k['score'].mean() for k in kpts_db]) + else: + scores = np.array([k['score'] for k in kpts_db]) + + kpts = np.array([k['keypoints'].flatten() for k in kpts_db]) + areas = np.array([k['area'] for k in kpts_db]) + + order = scores.argsort()[::-1] + + keep = [] + while len(order) > 0: + i = order[0] + keep.append(i) + + oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], + sigmas, vis_thr) + + inds = np.where(oks_ovr <= thr)[0] + order = order[inds + 1] + + keep = np.array(keep) + + return keep + + +def _rescore(overlap: np.ndarray, + scores: np.ndarray, + thr: float, + type: str = 'gaussian'): + """Rescoring mechanism gaussian or linear. + + Args: + overlap (np.ndarray): The calculated oks ious. + scores (np.ndarray): target scores. + thr (float): retain oks overlap < thr. + type (str): The rescoring type. Could be 'gaussian' or 'linear'. + Defaults to ``'gaussian'`` + + Returns: + np.ndarray: indexes to keep + """ + assert len(overlap) == len(scores) + assert type in ['gaussian', 'linear'] + + if type == 'linear': + inds = np.where(overlap >= thr)[0] + scores[inds] = scores[inds] * (1 - overlap[inds]) + else: + scores = scores * np.exp(-overlap**2 / thr) + + return scores + + +def soft_oks_nms(kpts_db: List[dict], + thr: float, + max_dets: int = 20, + sigmas: Optional[np.ndarray] = None, + vis_thr: Optional[float] = None, + score_per_joint: bool = False): + """Soft OKS NMS implementations. + + Args: + kpts_db (List[dict]): The keypoints results of the same image. + thr (float): The threshold of NMS. Will retain oks overlap < thr. + max_dets (int): Maximum number of detections to keep. Defaults to 20 + sigmas (np.ndarray, optional): Keypoint labelling uncertainty. + Please refer to `COCO keypoint evaluation + `__ for more details. + If not given, use the sigmas on COCO dataset. Defaults to ``None`` + vis_thr(float, optional): Threshold of the keypoint visibility. + If specified, will calculate OKS based on those keypoints whose + visibility higher than vis_thr. If not given, calculate the OKS + based on all keypoints. Defaults to ``None`` + score_per_joint(bool): Whether the input scores (in kpts_db) are + per-joint scores. Defaults to ``False`` + + Returns: + np.ndarray: indexes to keep. + """ + if len(kpts_db) == 0: + return [] + + if score_per_joint: + scores = np.array([k['score'].mean() for k in kpts_db]) + else: + scores = np.array([k['score'] for k in kpts_db]) + + kpts = np.array([k['keypoints'].flatten() for k in kpts_db]) + areas = np.array([k['area'] for k in kpts_db]) + + order = scores.argsort()[::-1] + scores = scores[order] + + keep = np.zeros(max_dets, dtype=np.intp) + keep_cnt = 0 + while len(order) > 0 and keep_cnt < max_dets: + i = order[0] + + oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], + sigmas, vis_thr) + + order = order[1:] + scores = _rescore(oks_ovr, scores[1:], thr) + + tmp = scores.argsort()[::-1] + order = order[tmp] + scores = scores[tmp] + + keep[keep_cnt] = i + keep_cnt += 1 + + keep = keep[:keep_cnt] + + return keep + + +def nearby_joints_nms( + kpts_db: List[dict], + dist_thr: float, + num_nearby_joints_thr: Optional[int] = None, + score_per_joint: bool = False, + max_dets: int = 30, +): + """Nearby joints NMS implementations. Instances with non-maximum scores + will be suppressed if they have too much closed joints with other + instances. This function is modified from project + `DEKR`. + + Args: + kpts_db (list[dict]): keypoints and scores. + dist_thr (float): threshold for judging whether two joints are close. + num_nearby_joints_thr (int): threshold for judging whether two + instances are close. + max_dets (int): max number of detections to keep. + score_per_joint (bool): the input scores (in kpts_db) are per joint + scores. + + Returns: + np.ndarray: indexes to keep. + """ + + assert dist_thr > 0, '`dist_thr` must be greater than 0.' + if len(kpts_db) == 0: + return [] + + if score_per_joint: + scores = np.array([k['score'].mean() for k in kpts_db]) + else: + scores = np.array([k['score'] for k in kpts_db]) + + kpts = np.array([k['keypoints'] for k in kpts_db]) + + num_people, num_joints, _ = kpts.shape + if num_nearby_joints_thr is None: + num_nearby_joints_thr = num_joints // 2 + assert num_nearby_joints_thr < num_joints, '`num_nearby_joints_thr` must '\ + 'be less than the number of joints.' + + # compute distance threshold + pose_area = kpts.max(axis=1) - kpts.min(axis=1) + pose_area = np.sqrt(np.power(pose_area, 2).sum(axis=1)) + pose_area = pose_area.reshape(num_people, 1, 1) + pose_area = np.tile(pose_area, (num_people, num_joints)) + close_dist_thr = pose_area * dist_thr + + # count nearby joints between instances + instance_dist = kpts[:, None] - kpts + instance_dist = np.sqrt(np.power(instance_dist, 2).sum(axis=3)) + close_instance_num = (instance_dist < close_dist_thr).sum(2) + close_instance = close_instance_num > num_nearby_joints_thr + + # apply nms + ignored_pose_inds, keep_pose_inds = set(), list() + indexes = np.argsort(scores)[::-1] + for i in indexes: + if i in ignored_pose_inds: + continue + keep_inds = close_instance[i].nonzero()[0] + keep_ind = keep_inds[np.argmax(scores[keep_inds])] + if keep_ind not in ignored_pose_inds: + keep_pose_inds.append(keep_ind) + ignored_pose_inds = ignored_pose_inds.union(set(keep_inds)) + + # limit the number of output instances + if max_dets > 0 and len(keep_pose_inds) > max_dets: + sub_inds = np.argsort(scores[keep_pose_inds])[-1:-max_dets - 1:-1] + keep_pose_inds = [keep_pose_inds[i] for i in sub_inds] + + return keep_pose_inds diff --git a/mmpose/evaluation/metrics/__init__.py b/mmpose/evaluation/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f02c353ef7d3d973a5ab3fa88128fa0814c6a1c7 --- /dev/null +++ b/mmpose/evaluation/metrics/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .coco_metric import CocoMetric +from .coco_wholebody_metric import CocoWholeBodyMetric +from .keypoint_2d_metrics import (AUC, EPE, NME, JhmdbPCKAccuracy, + MpiiPCKAccuracy, PCKAccuracy) +from .keypoint_partition_metric import KeypointPartitionMetric +from .posetrack18_metric import PoseTrack18Metric + +__all__ = [ + 'CocoMetric', 'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'AUC', + 'EPE', 'NME', 'PoseTrack18Metric', 'CocoWholeBodyMetric', + 'KeypointPartitionMetric' +] diff --git a/mmpose/evaluation/metrics/__pycache__/__init__.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..987a926b3e38850427b129f0ea66b45fe96bae52 Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/evaluation/metrics/__pycache__/coco_metric.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/coco_metric.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2915fcb40b533eff868d71d71348e28de59fd1f5 Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/coco_metric.cpython-38.pyc differ diff --git a/mmpose/evaluation/metrics/__pycache__/coco_wholebody_metric.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/coco_wholebody_metric.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6127cdb481d081ab5a5e41d8d5bed1a2ace39225 Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/coco_wholebody_metric.cpython-38.pyc differ diff --git a/mmpose/evaluation/metrics/__pycache__/keypoint_2d_metrics.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/keypoint_2d_metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5f9ef8c87b51e8c35d4b8849b9c2fea2c6be6c6 Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/keypoint_2d_metrics.cpython-38.pyc differ diff --git a/mmpose/evaluation/metrics/__pycache__/keypoint_partition_metric.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/keypoint_partition_metric.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce8be0eca9a1d86bda8566ebdd6ff6ca18a19d0 Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/keypoint_partition_metric.cpython-38.pyc differ diff --git a/mmpose/evaluation/metrics/__pycache__/posetrack18_metric.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/posetrack18_metric.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66b861698ee1776f390beb79b0898666257ed862 Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/posetrack18_metric.cpython-38.pyc differ diff --git a/mmpose/evaluation/metrics/coco_metric.py b/mmpose/evaluation/metrics/coco_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..8327e2eca7b978f66f3329fa1f45c3619e528c89 --- /dev/null +++ b/mmpose/evaluation/metrics/coco_metric.py @@ -0,0 +1,550 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import os.path as osp +import tempfile +from collections import OrderedDict, defaultdict +from typing import Dict, Optional, Sequence + +import numpy as np +from mmengine.evaluator import BaseMetric +from mmengine.fileio import dump, get_local_path, load +from mmengine.logging import MMLogger +from xtcocotools.coco import COCO +from xtcocotools.cocoeval import COCOeval + +from mmpose.registry import METRICS +from ..functional import oks_nms, soft_oks_nms + + +@METRICS.register_module() +class CocoMetric(BaseMetric): + """COCO pose estimation task evaluation metric. + + Evaluate AR, AP, and mAP for keypoint detection tasks. Support COCO + dataset and other datasets in COCO format. Please refer to + `COCO keypoint evaluation `__ + for more details. + + Args: + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None + use_area (bool): Whether to use ``'area'`` message in the annotations. + If the ground truth annotations (e.g. CrowdPose, AIC) do not have + the field ``'area'``, please set ``use_area=False``. + Defaults to ``True`` + iou_type (str): The same parameter as `iouType` in + :class:`xtcocotools.COCOeval`, which can be ``'keypoints'``, or + ``'keypoints_crowd'`` (used in CrowdPose dataset). + Defaults to ``'keypoints'`` + score_mode (str): The mode to score the prediction results which + should be one of the following options: + + - ``'bbox'``: Take the score of bbox as the score of the + prediction results. + - ``'bbox_keypoint'``: Use keypoint score to rescore the + prediction results. + - ``'bbox_rle'``: Use rle_score to rescore the + prediction results. + + Defaults to ``'bbox_keypoint'` + keypoint_score_thr (float): The threshold of keypoint score. The + keypoints with score lower than it will not be included to + rescore the prediction results. Valid only when ``score_mode`` is + ``bbox_keypoint``. Defaults to ``0.2`` + nms_mode (str): The mode to perform Non-Maximum Suppression (NMS), + which should be one of the following options: + + - ``'oks_nms'``: Use Object Keypoint Similarity (OKS) to + perform NMS. + - ``'soft_oks_nms'``: Use Object Keypoint Similarity (OKS) + to perform soft NMS. + - ``'none'``: Do not perform NMS. Typically for bottomup mode + output. + + Defaults to ``'oks_nms'` + nms_thr (float): The Object Keypoint Similarity (OKS) threshold + used in NMS when ``nms_mode`` is ``'oks_nms'`` or + ``'soft_oks_nms'``. Will retain the prediction results with OKS + lower than ``nms_thr``. Defaults to ``0.9`` + format_only (bool): Whether only format the output results without + doing quantitative evaluation. This is designed for the need of + test submission when the ground truth annotations are absent. If + set to ``True``, ``outfile_prefix`` should specify the path to + store the output results. Defaults to ``False`` + outfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., ``'a/b/prefix'``. + If not specified, a temp file will be created. Defaults to ``None`` + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be ``'cpu'`` or + ``'gpu'``. Defaults to ``'cpu'`` + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, ``self.default_prefix`` + will be used instead. Defaults to ``None`` + """ + default_prefix: Optional[str] = 'coco' + + def __init__(self, + ann_file: Optional[str] = None, + use_area: bool = True, + iou_type: str = 'keypoints', + score_mode: str = 'bbox_keypoint', + keypoint_score_thr: float = 0.2, + nms_mode: str = 'oks_nms', + nms_thr: float = 0.9, + format_only: bool = False, + outfile_prefix: Optional[str] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.ann_file = ann_file + # initialize coco helper with the annotation json file + # if ann_file is not specified, initialize with the converted dataset + if ann_file is not None: + with get_local_path(ann_file) as local_path: + self.coco = COCO(local_path) + else: + self.coco = None + + self.use_area = use_area + self.iou_type = iou_type + + allowed_score_modes = ['bbox', 'bbox_keypoint', 'bbox_rle', 'keypoint'] + if score_mode not in allowed_score_modes: + raise ValueError( + "`score_mode` should be one of 'bbox', 'bbox_keypoint', " + f"'bbox_rle', but got {score_mode}") + self.score_mode = score_mode + self.keypoint_score_thr = keypoint_score_thr + + allowed_nms_modes = ['oks_nms', 'soft_oks_nms', 'none'] + if nms_mode not in allowed_nms_modes: + raise ValueError( + "`nms_mode` should be one of 'oks_nms', 'soft_oks_nms', " + f"'none', but got {nms_mode}") + self.nms_mode = nms_mode + self.nms_thr = nms_thr + + if format_only: + assert outfile_prefix is not None, '`outfile_prefix` can not be '\ + 'None when `format_only` is True, otherwise the result file '\ + 'will be saved to a temp directory which will be cleaned up '\ + 'in the end.' + elif ann_file is not None: + # do evaluation only if the ground truth annotations exist + assert 'annotations' in load(ann_file), \ + 'Ground truth annotations are required for evaluation '\ + 'when `format_only` is False.' + + self.format_only = format_only + self.outfile_prefix = outfile_prefix + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data + from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from + the model, each of which has the following keys: + + - 'id': The id of the sample + - 'img_id': The image_id of the sample + - 'pred_instances': The prediction results of instance(s) + """ + for data_sample in data_samples: + if 'pred_instances' not in data_sample: + raise ValueError( + '`pred_instances` are required to process the ' + f'predictions results in {self.__class__.__name__}. ') + + # keypoints.shape: [N, K, 2], + # N: number of instances, K: number of keypoints + # for topdown-style output, N is usually 1, while for + # bottomup-style output, N is the number of instances in the image + keypoints = data_sample['pred_instances']['keypoints'] + # [N, K], the scores for all keypoints of all instances + keypoint_scores = data_sample['pred_instances']['keypoint_scores'] + assert keypoint_scores.shape == keypoints.shape[:2] + + # parse prediction results + pred = dict() + pred['id'] = data_sample['id'] + pred['img_id'] = data_sample['img_id'] + pred['keypoints'] = keypoints + pred['keypoint_scores'] = keypoint_scores + pred['category_id'] = data_sample.get('category_id', 1) + + if 'bbox_scores' in data_sample['pred_instances']: + # some one-stage models will predict bboxes and scores + # together with keypoints + bbox_scores = data_sample['pred_instances']['bbox_scores'] + elif ('bbox_scores' not in data_sample['gt_instances'] + or len(data_sample['gt_instances']['bbox_scores']) != + len(keypoints)): + # bottom-up models might output different number of + # instances from annotation + bbox_scores = np.ones(len(keypoints)) + else: + # top-down models use detected bboxes, the scores of which + # are contained in the gt_instances + bbox_scores = data_sample['gt_instances']['bbox_scores'] + pred['bbox_scores'] = bbox_scores + + # get area information + if 'bbox_scales' in data_sample['gt_instances']: + pred['areas'] = np.prod( + data_sample['gt_instances']['bbox_scales'], axis=1) + + # parse gt + gt = dict() + if self.coco is None: + gt['width'] = data_sample['ori_shape'][1] + gt['height'] = data_sample['ori_shape'][0] + gt['img_id'] = data_sample['img_id'] + if self.iou_type == 'keypoints_crowd': + assert 'crowd_index' in data_sample, \ + '`crowd_index` is required when `self.iou_type` is ' \ + '`keypoints_crowd`' + gt['crowd_index'] = data_sample['crowd_index'] + assert 'raw_ann_info' in data_sample, \ + 'The row ground truth annotations are required for ' \ + 'evaluation when `ann_file` is not provided' + anns = data_sample['raw_ann_info'] + gt['raw_ann_info'] = anns if isinstance(anns, list) else [anns] + + # add converted result to the results list + self.results.append((pred, gt)) + + def gt_to_coco_json(self, gt_dicts: Sequence[dict], + outfile_prefix: str) -> str: + """Convert ground truth to coco format json file. + + Args: + gt_dicts (Sequence[dict]): Ground truth of the dataset. Each dict + contains the ground truth information about the data sample. + Required keys of the each `gt_dict` in `gt_dicts`: + - `img_id`: image id of the data sample + - `width`: original image width + - `height`: original image height + - `raw_ann_info`: the raw annotation information + Optional keys: + - `crowd_index`: measure the crowding level of an image, + defined in CrowdPose dataset + It is worth mentioning that, in order to compute `CocoMetric`, + there are some required keys in the `raw_ann_info`: + - `id`: the id to distinguish different annotations + - `image_id`: the image id of this annotation + - `category_id`: the category of the instance. + - `bbox`: the object bounding box + - `keypoints`: the keypoints cooridinates along with their + visibilities. Note that it need to be aligned + with the official COCO format, e.g., a list with length + N * 3, in which N is the number of keypoints. And each + triplet represent the [x, y, visible] of the keypoint. + - `iscrowd`: indicating whether the annotation is a crowd. + It is useful when matching the detection results to + the ground truth. + There are some optional keys as well: + - `area`: it is necessary when `self.use_area` is `True` + - `num_keypoints`: it is necessary when `self.iou_type` + is set as `keypoints_crowd`. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json file will be named + "somepath/xxx.gt.json". + Returns: + str: The filename of the json file. + """ + image_infos = [] + annotations = [] + img_ids = [] + ann_ids = [] + + for gt_dict in gt_dicts: + # filter duplicate image_info + if gt_dict['img_id'] not in img_ids: + image_info = dict( + id=gt_dict['img_id'], + width=gt_dict['width'], + height=gt_dict['height'], + ) + if self.iou_type == 'keypoints_crowd': + image_info['crowdIndex'] = gt_dict['crowd_index'] + + image_infos.append(image_info) + img_ids.append(gt_dict['img_id']) + + # filter duplicate annotations + for ann in gt_dict['raw_ann_info']: + if ann is None: + # during evaluation on bottom-up datasets, some images + # do not have instance annotation + continue + + annotation = dict( + id=ann['id'], + image_id=ann['image_id'], + category_id=ann['category_id'], + bbox=ann['bbox'], + keypoints=ann['keypoints'], + iscrowd=ann['iscrowd'], + ) + if self.use_area: + assert 'area' in ann, \ + '`area` is required when `self.use_area` is `True`' + annotation['area'] = ann['area'] + + if self.iou_type == 'keypoints_crowd': + assert 'num_keypoints' in ann, \ + '`num_keypoints` is required when `self.iou_type` ' \ + 'is `keypoints_crowd`' + annotation['num_keypoints'] = ann['num_keypoints'] + + annotations.append(annotation) + ann_ids.append(ann['id']) + + info = dict( + date_created=str(datetime.datetime.now()), + description='Coco json file converted by mmpose CocoMetric.') + coco_json = dict( + info=info, + images=image_infos, + categories=self.dataset_meta['CLASSES'], + licenses=None, + annotations=annotations, + ) + converted_json_path = f'{outfile_prefix}.gt.json' + dump(coco_json, converted_json_path, sort_keys=True, indent=4) + return converted_json_path + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # split prediction and gt list + preds, gts = zip(*results) + + tmp_dir = None + if self.outfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + outfile_prefix = osp.join(tmp_dir.name, 'results') + else: + outfile_prefix = self.outfile_prefix + + if self.coco is None: + # use converted gt json file to initialize coco helper + logger.info('Converting ground truth to coco format...') + coco_json_path = self.gt_to_coco_json( + gt_dicts=gts, outfile_prefix=outfile_prefix) + self.coco = COCO(coco_json_path) + + kpts = defaultdict(list) + + # group the preds by img_id + for pred in preds: + img_id = pred['img_id'] + for idx in range(len(pred['keypoints'])): + instance = { + 'id': pred['id'], + 'img_id': pred['img_id'], + 'category_id': pred['category_id'], + 'keypoints': pred['keypoints'][idx], + 'keypoint_scores': pred['keypoint_scores'][idx], + 'bbox_score': pred['bbox_scores'][idx], + } + + if 'areas' in pred: + instance['area'] = pred['areas'][idx] + else: + # use keypoint to calculate bbox and get area + keypoints = pred['keypoints'][idx] + area = ( + np.max(keypoints[:, 0]) - np.min(keypoints[:, 0])) * ( + np.max(keypoints[:, 1]) - np.min(keypoints[:, 1])) + instance['area'] = area + + kpts[img_id].append(instance) + + # sort keypoint results according to id and remove duplicate ones + kpts = self._sort_and_unique_bboxes(kpts, key='id') + + # score the prediction results according to `score_mode` + # and perform NMS according to `nms_mode` + valid_kpts = defaultdict(list) + num_keypoints = self.dataset_meta['num_keypoints'] + for img_id, instances in kpts.items(): + for instance in instances: + # concatenate the keypoint coordinates and scores + instance['keypoints'] = np.concatenate([ + instance['keypoints'], instance['keypoint_scores'][:, None] + ], + axis=-1) + if self.score_mode == 'bbox': + instance['score'] = instance['bbox_score'] + elif self.score_mode == 'keypoint': + instance['score'] = np.mean(instance['keypoint_scores']) + else: + bbox_score = instance['bbox_score'] + if self.score_mode == 'bbox_rle': + keypoint_scores = instance['keypoint_scores'] + instance['score'] = float(bbox_score + + np.mean(keypoint_scores) + + np.max(keypoint_scores)) + + else: # self.score_mode == 'bbox_keypoint': + mean_kpt_score = 0 + valid_num = 0 + for kpt_idx in range(num_keypoints): + kpt_score = instance['keypoint_scores'][kpt_idx] + if kpt_score > self.keypoint_score_thr: + mean_kpt_score += kpt_score + valid_num += 1 + if valid_num != 0: + mean_kpt_score /= valid_num + instance['score'] = bbox_score * mean_kpt_score + # perform nms + if self.nms_mode == 'none': + valid_kpts[img_id] = instances + else: + nms = oks_nms if self.nms_mode == 'oks_nms' else soft_oks_nms + keep = nms( + instances, + self.nms_thr, + sigmas=self.dataset_meta['sigmas']) + valid_kpts[img_id] = [instances[_keep] for _keep in keep] + + # convert results to coco style and dump into a json file + self.results2json(valid_kpts, outfile_prefix=outfile_prefix) + + # only format the results without doing quantitative evaluation + if self.format_only: + logger.info('results are saved in ' + f'{osp.dirname(outfile_prefix)}') + return {} + + # evaluation results + eval_results = OrderedDict() + logger.info(f'Evaluating {self.__class__.__name__}...') + info_str = self._do_python_keypoint_eval(outfile_prefix) + name_value = OrderedDict(info_str) + eval_results.update(name_value) + + if tmp_dir is not None: + tmp_dir.cleanup() + return eval_results + + def results2json(self, keypoints: Dict[int, list], + outfile_prefix: str) -> str: + """Dump the keypoint detection results to a COCO style json file. + + Args: + keypoints (Dict[int, list]): Keypoint detection results + of the dataset. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.keypoints.json", + + Returns: + str: The json file name of keypoint results. + """ + # the results with category_id + cat_results = [] + + for _, img_kpts in keypoints.items(): + _keypoints = np.array( + [img_kpt['keypoints'] for img_kpt in img_kpts]) + num_keypoints = self.dataset_meta['num_keypoints'] + # collect all the person keypoints in current image + _keypoints = _keypoints.reshape(-1, num_keypoints * 3) + + result = [{ + 'image_id': img_kpt['img_id'], + 'category_id': img_kpt['category_id'], + 'keypoints': keypoint.tolist(), + 'score': float(img_kpt['score']), + } for img_kpt, keypoint in zip(img_kpts, _keypoints)] + + cat_results.extend(result) + + res_file = f'{outfile_prefix}.keypoints.json' + dump(cat_results, res_file, sort_keys=True, indent=4) + + def _do_python_keypoint_eval(self, outfile_prefix: str) -> list: + """Do keypoint evaluation using COCOAPI. + + Args: + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.keypoints.json", + + Returns: + list: a list of tuples. Each tuple contains the evaluation stats + name and corresponding stats value. + """ + res_file = f'{outfile_prefix}.keypoints.json' + coco_det = self.coco.loadRes(res_file) + sigmas = self.dataset_meta['sigmas'] + coco_eval = COCOeval(self.coco, coco_det, self.iou_type, sigmas, + self.use_area) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + if self.iou_type == 'keypoints_crowd': + stats_names = [ + 'AP', 'AP .5', 'AP .75', 'AR', 'AR .5', 'AR .75', 'AP(E)', + 'AP(M)', 'AP(H)' + ] + else: + stats_names = [ + 'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', + 'AR .75', 'AR (M)', 'AR (L)' + ] + + info_str = list(zip(stats_names, coco_eval.stats)) + + return info_str + + def _sort_and_unique_bboxes(self, + kpts: Dict[int, list], + key: str = 'id') -> Dict[int, list]: + """Sort keypoint detection results in each image and remove the + duplicate ones. Usually performed in multi-batch testing. + + Args: + kpts (Dict[int, list]): keypoint prediction results. The keys are + '`img_id`' and the values are list that may contain + keypoints of multiple persons. Each element in the list is a + dict containing the ``'key'`` field. + See the argument ``key`` for details. + key (str): The key name in each person prediction results. The + corresponding value will be used for sorting the results. + Default: ``'id'``. + + Returns: + Dict[int, list]: The sorted keypoint detection results. + """ + for img_id, persons in kpts.items(): + # deal with bottomup-style output + if isinstance(kpts[img_id][0][key], Sequence): + return kpts + num = len(persons) + kpts[img_id] = sorted(kpts[img_id], key=lambda x: x[key]) + for i in range(num - 1, 0, -1): + if kpts[img_id][i][key] == kpts[img_id][i - 1][key]: + del kpts[img_id][i] + + return kpts diff --git a/mmpose/evaluation/metrics/coco_wholebody_metric.py b/mmpose/evaluation/metrics/coco_wholebody_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..c5675f54c8fe793f3f5c1a30284dc2ab44b28ccd --- /dev/null +++ b/mmpose/evaluation/metrics/coco_wholebody_metric.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +from typing import Dict, Optional, Sequence + +import numpy as np +from mmengine.fileio import dump +from xtcocotools.cocoeval import COCOeval + +from mmpose.registry import METRICS +from .coco_metric import CocoMetric + + +@METRICS.register_module() +class CocoWholeBodyMetric(CocoMetric): + """COCO-WholeBody evaluation metric. + + Evaluate AR, AP, and mAP for COCO-WholeBody keypoint detection tasks. + Support COCO-WholeBody dataset. Please refer to + `COCO keypoint evaluation `__ + for more details. + + Args: + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None + use_area (bool): Whether to use ``'area'`` message in the annotations. + If the ground truth annotations (e.g. CrowdPose, AIC) do not have + the field ``'area'``, please set ``use_area=False``. + Defaults to ``True`` + iou_type (str): The same parameter as `iouType` in + :class:`xtcocotools.COCOeval`, which can be ``'keypoints'``, or + ``'keypoints_crowd'`` (used in CrowdPose dataset). + Defaults to ``'keypoints'`` + score_mode (str): The mode to score the prediction results which + should be one of the following options: + + - ``'bbox'``: Take the score of bbox as the score of the + prediction results. + - ``'bbox_keypoint'``: Use keypoint score to rescore the + prediction results. + - ``'bbox_rle'``: Use rle_score to rescore the + prediction results. + + Defaults to ``'bbox_keypoint'` + keypoint_score_thr (float): The threshold of keypoint score. The + keypoints with score lower than it will not be included to + rescore the prediction results. Valid only when ``score_mode`` is + ``bbox_keypoint``. Defaults to ``0.2`` + nms_mode (str): The mode to perform Non-Maximum Suppression (NMS), + which should be one of the following options: + + - ``'oks_nms'``: Use Object Keypoint Similarity (OKS) to + perform NMS. + - ``'soft_oks_nms'``: Use Object Keypoint Similarity (OKS) + to perform soft NMS. + - ``'none'``: Do not perform NMS. Typically for bottomup mode + output. + + Defaults to ``'oks_nms'` + nms_thr (float): The Object Keypoint Similarity (OKS) threshold + used in NMS when ``nms_mode`` is ``'oks_nms'`` or + ``'soft_oks_nms'``. Will retain the prediction results with OKS + lower than ``nms_thr``. Defaults to ``0.9`` + format_only (bool): Whether only format the output results without + doing quantitative evaluation. This is designed for the need of + test submission when the ground truth annotations are absent. If + set to ``True``, ``outfile_prefix`` should specify the path to + store the output results. Defaults to ``False`` + outfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., ``'a/b/prefix'``. + If not specified, a temp file will be created. Defaults to ``None`` + **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric` + """ + default_prefix: Optional[str] = 'coco-wholebody' + body_num = 17 + foot_num = 6 + face_num = 68 + left_hand_num = 21 + right_hand_num = 21 + + def gt_to_coco_json(self, gt_dicts: Sequence[dict], + outfile_prefix: str) -> str: + """Convert ground truth to coco format json file. + + Args: + gt_dicts (Sequence[dict]): Ground truth of the dataset. Each dict + contains the ground truth information about the data sample. + Required keys of the each `gt_dict` in `gt_dicts`: + - `img_id`: image id of the data sample + - `width`: original image width + - `height`: original image height + - `raw_ann_info`: the raw annotation information + Optional keys: + - `crowd_index`: measure the crowding level of an image, + defined in CrowdPose dataset + It is worth mentioning that, in order to compute `CocoMetric`, + there are some required keys in the `raw_ann_info`: + - `id`: the id to distinguish different annotations + - `image_id`: the image id of this annotation + - `category_id`: the category of the instance. + - `bbox`: the object bounding box + - `keypoints`: the keypoints cooridinates along with their + visibilities. Note that it need to be aligned + with the official COCO format, e.g., a list with length + N * 3, in which N is the number of keypoints. And each + triplet represent the [x, y, visible] of the keypoint. + - 'keypoints' + - `iscrowd`: indicating whether the annotation is a crowd. + It is useful when matching the detection results to + the ground truth. + There are some optional keys as well: + - `area`: it is necessary when `self.use_area` is `True` + - `num_keypoints`: it is necessary when `self.iou_type` + is set as `keypoints_crowd`. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json file will be named + "somepath/xxx.gt.json". + Returns: + str: The filename of the json file. + """ + image_infos = [] + annotations = [] + img_ids = [] + ann_ids = [] + + for gt_dict in gt_dicts: + # filter duplicate image_info + if gt_dict['img_id'] not in img_ids: + image_info = dict( + id=gt_dict['img_id'], + width=gt_dict['width'], + height=gt_dict['height'], + ) + if self.iou_type == 'keypoints_crowd': + image_info['crowdIndex'] = gt_dict['crowd_index'] + + image_infos.append(image_info) + img_ids.append(gt_dict['img_id']) + + # filter duplicate annotations + for ann in gt_dict['raw_ann_info']: + annotation = dict( + id=ann['id'], + image_id=ann['image_id'], + category_id=ann['category_id'], + bbox=ann['bbox'], + keypoints=ann['keypoints'], + foot_kpts=ann['foot_kpts'], + face_kpts=ann['face_kpts'], + lefthand_kpts=ann['lefthand_kpts'], + righthand_kpts=ann['righthand_kpts'], + iscrowd=ann['iscrowd'], + ) + if self.use_area: + assert 'area' in ann, \ + '`area` is required when `self.use_area` is `True`' + annotation['area'] = ann['area'] + + annotations.append(annotation) + ann_ids.append(ann['id']) + + info = dict( + date_created=str(datetime.datetime.now()), + description='Coco json file converted by mmpose CocoMetric.') + coco_json: dict = dict( + info=info, + images=image_infos, + categories=self.dataset_meta['CLASSES'], + licenses=None, + annotations=annotations, + ) + converted_json_path = f'{outfile_prefix}.gt.json' + dump(coco_json, converted_json_path, sort_keys=True, indent=4) + return converted_json_path + + def results2json(self, keypoints: Dict[int, list], + outfile_prefix: str) -> str: + """Dump the keypoint detection results to a COCO style json file. + + Args: + keypoints (Dict[int, list]): Keypoint detection results + of the dataset. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.keypoints.json", + + Returns: + str: The json file name of keypoint results. + """ + # the results with category_id + cat_id = 1 + cat_results = [] + + cuts = np.cumsum([ + 0, self.body_num, self.foot_num, self.face_num, self.left_hand_num, + self.right_hand_num + ]) * 3 + + for _, img_kpts in keypoints.items(): + _keypoints = np.array( + [img_kpt['keypoints'] for img_kpt in img_kpts]) + num_keypoints = self.dataset_meta['num_keypoints'] + # collect all the person keypoints in current image + _keypoints = _keypoints.reshape(-1, num_keypoints * 3) + + result = [{ + 'image_id': img_kpt['img_id'], + 'category_id': cat_id, + 'keypoints': _keypoint[cuts[0]:cuts[1]].tolist(), + 'foot_kpts': _keypoint[cuts[1]:cuts[2]].tolist(), + 'face_kpts': _keypoint[cuts[2]:cuts[3]].tolist(), + 'lefthand_kpts': _keypoint[cuts[3]:cuts[4]].tolist(), + 'righthand_kpts': _keypoint[cuts[4]:cuts[5]].tolist(), + 'score': float(img_kpt['score']), + } for img_kpt, _keypoint in zip(img_kpts, _keypoints)] + + cat_results.extend(result) + + res_file = f'{outfile_prefix}.keypoints.json' + dump(cat_results, res_file, sort_keys=True, indent=4) + + def _do_python_keypoint_eval(self, outfile_prefix: str) -> list: + """Do keypoint evaluation using COCOAPI. + + Args: + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.keypoints.json", + + Returns: + list: a list of tuples. Each tuple contains the evaluation stats + name and corresponding stats value. + """ + res_file = f'{outfile_prefix}.keypoints.json' + coco_det = self.coco.loadRes(res_file) + sigmas = self.dataset_meta['sigmas'] + + cuts = np.cumsum([ + 0, self.body_num, self.foot_num, self.face_num, self.left_hand_num, + self.right_hand_num + ]) + + coco_eval = COCOeval( + self.coco, + coco_det, + 'keypoints_body', + sigmas[cuts[0]:cuts[1]], + use_area=True) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, + coco_det, + 'keypoints_foot', + sigmas[cuts[1]:cuts[2]], + use_area=True) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, + coco_det, + 'keypoints_face', + sigmas[cuts[2]:cuts[3]], + use_area=True) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, + coco_det, + 'keypoints_lefthand', + sigmas[cuts[3]:cuts[4]], + use_area=True) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, + coco_det, + 'keypoints_righthand', + sigmas[cuts[4]:cuts[5]], + use_area=True) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, coco_det, 'keypoints_wholebody', sigmas, use_area=True) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + stats_names = [ + 'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', + 'AR .75', 'AR (M)', 'AR (L)' + ] + + info_str = list(zip(stats_names, coco_eval.stats)) + + return info_str diff --git a/mmpose/evaluation/metrics/keypoint_2d_metrics.py b/mmpose/evaluation/metrics/keypoint_2d_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a63f1e5193872493fc32ec852a3a31dafa17cd --- /dev/null +++ b/mmpose/evaluation/metrics/keypoint_2d_metrics.py @@ -0,0 +1,910 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Dict, Optional, Sequence, Union + +import numpy as np +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger + +from mmpose.registry import METRICS +from ..functional import (keypoint_auc, keypoint_epe, keypoint_nme, + keypoint_pck_accuracy) + + +@METRICS.register_module() +class PCKAccuracy(BaseMetric): + """PCK accuracy evaluation metric. + Calculate the pose accuracy of Percentage of Correct Keypoints (PCK) for + each individual keypoint and the averaged accuracy across all keypoints. + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the person bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + Note: + - length of dataset: N + - num_keypoints: K + - number of keypoint dimensions: D (typically D = 2) + Args: + thr(float): Threshold of PCK calculation. Default: 0.05. + norm_item (str | Sequence[str]): The item used for normalization. + Valid items include 'bbox', 'head', 'torso', which correspond + to 'PCK', 'PCKh' and 'tPCK' respectively. Default: ``'bbox'``. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be ``'cpu'`` or + ``'gpu'``. Default: ``'cpu'``. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, ``self.default_prefix`` + will be used instead. Default: ``None``. + + Examples: + + >>> from mmpose.evaluation.metrics import PCKAccuracy + >>> import numpy as np + >>> from mmengine.structures import InstanceData + >>> num_keypoints = 15 + >>> keypoints = np.random.random((1, num_keypoints, 2)) * 10 + >>> gt_instances = InstanceData() + >>> gt_instances.keypoints = keypoints + >>> gt_instances.keypoints_visible = np.ones( + ... (1, num_keypoints, 1)).astype(bool) + >>> gt_instances.bboxes = np.random.random((1, 4)) * 20 + >>> pred_instances = InstanceData() + >>> pred_instances.keypoints = keypoints + >>> data_sample = { + ... 'gt_instances': gt_instances.to_dict(), + ... 'pred_instances': pred_instances.to_dict(), + ... } + >>> data_samples = [data_sample] + >>> data_batch = [{'inputs': None}] + >>> pck_metric = PCKAccuracy(thr=0.5, norm_item='bbox') + ...: UserWarning: The prefix is not set in metric class PCKAccuracy. + >>> pck_metric.process(data_batch, data_samples) + >>> pck_metric.evaluate(1) + 10/26 15:37:57 - mmengine - INFO - Evaluating PCKAccuracy (normalized by ``"bbox_size"``)... # noqa + {'PCK': 1.0} + + """ + + def __init__(self, + thr: float = 0.05, + norm_item: Union[str, Sequence[str]] = 'bbox', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.thr = thr + self.norm_item = norm_item if isinstance(norm_item, + (tuple, + list)) else [norm_item] + allow_normalized_items = ['bbox', 'head', 'torso'] + for item in self.norm_item: + if item not in allow_normalized_items: + raise KeyError( + f'The normalized item {item} is not supported by ' + f"{self.__class__.__name__}. Should be one of 'bbox', " + f"'head', 'torso', but got {item}.") + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. + + The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + Args: + data_batch (Sequence[dict]): A batch of data + from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from + the model. + """ + for data_sample in data_samples: + # predicted keypoints coordinates, [1, K, D] + pred_coords = data_sample['pred_instances']['keypoints'] + # ground truth data_info + gt = data_sample['gt_instances'] + # ground truth keypoints coordinates, [1, K, D] + gt_coords = gt['keypoints'] + # ground truth keypoints_visible, [1, K, 1] + mask = gt['keypoints_visible'].astype(bool).reshape(1, -1) + + result = { + 'pred_coords': pred_coords, + 'gt_coords': gt_coords, + 'mask': mask, + } + + if 'bbox' in self.norm_item: + assert 'bboxes' in gt, 'The ground truth data info do not ' \ + 'have the expected normalized_item ``"bbox"``.' + # ground truth bboxes, [1, 4] + bbox_size_ = np.max(gt['bboxes'][0][2:] - gt['bboxes'][0][:2]) + bbox_size = np.array([bbox_size_, bbox_size_]).reshape(-1, 2) + result['bbox_size'] = bbox_size + + if 'head' in self.norm_item: + assert 'head_size' in gt, 'The ground truth data info do ' \ + 'not have the expected normalized_item ``"head_size"``.' + # ground truth bboxes + head_size_ = gt['head_size'] + head_size = np.array([head_size_, head_size_]).reshape(-1, 2) + result['head_size'] = head_size + + if 'torso' in self.norm_item: + # used in JhmdbDataset + torso_size_ = np.linalg.norm(gt_coords[0][4] - gt_coords[0][5]) + if torso_size_ < 1: + torso_size_ = np.linalg.norm(pred_coords[0][4] - + pred_coords[0][5]) + warnings.warn('Ground truth torso size < 1. ' + 'Use torso size from predicted ' + 'keypoint results instead.') + torso_size = np.array([torso_size_, + torso_size_]).reshape(-1, 2) + result['torso_size'] = torso_size + + self.results.append(result) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + The returned result dict may have the following keys: + - 'PCK': The pck accuracy normalized by `bbox_size`. + - 'PCKh': The pck accuracy normalized by `head_size`. + - 'tPCK': The pck accuracy normalized by `torso_size`. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # pred_coords: [N, K, D] + pred_coords = np.concatenate( + [result['pred_coords'] for result in results]) + # gt_coords: [N, K, D] + gt_coords = np.concatenate([result['gt_coords'] for result in results]) + # mask: [N, K] + mask = np.concatenate([result['mask'] for result in results]) + + metrics = dict() + if 'bbox' in self.norm_item: + norm_size_bbox = np.concatenate( + [result['bbox_size'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__} ' + f'(normalized by ``"bbox_size"``)...') + + _, pck, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask, + self.thr, norm_size_bbox) + metrics['PCK'] = pck + + if 'head' in self.norm_item: + norm_size_head = np.concatenate( + [result['head_size'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__} ' + f'(normalized by ``"head_size"``)...') + + _, pckh, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask, + self.thr, norm_size_head) + metrics['PCKh'] = pckh + + if 'torso' in self.norm_item: + norm_size_torso = np.concatenate( + [result['torso_size'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__} ' + f'(normalized by ``"torso_size"``)...') + + _, tpck, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask, + self.thr, norm_size_torso) + metrics['tPCK'] = tpck + + return metrics + + +@METRICS.register_module() +class MpiiPCKAccuracy(PCKAccuracy): + """PCKh accuracy evaluation metric for MPII dataset. + + Calculate the pose accuracy of Percentage of Correct Keypoints (PCK) for + each individual keypoint and the averaged accuracy across all keypoints. + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the person bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + Note: + - length of dataset: N + - num_keypoints: K + - number of keypoint dimensions: D (typically D = 2) + + Args: + thr(float): Threshold of PCK calculation. Default: 0.05. + norm_item (str | Sequence[str]): The item used for normalization. + Valid items include 'bbox', 'head', 'torso', which correspond + to 'PCK', 'PCKh' and 'tPCK' respectively. Default: ``'head'``. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be ``'cpu'`` or + ``'gpu'``. Default: ``'cpu'``. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, ``self.default_prefix`` + will be used instead. Default: ``None``. + + Examples: + + >>> from mmpose.evaluation.metrics import MpiiPCKAccuracy + >>> import numpy as np + >>> from mmengine.structures import InstanceData + >>> num_keypoints = 16 + >>> keypoints = np.random.random((1, num_keypoints, 2)) * 10 + >>> gt_instances = InstanceData() + >>> gt_instances.keypoints = keypoints + 1.0 + >>> gt_instances.keypoints_visible = np.ones( + ... (1, num_keypoints, 1)).astype(bool) + >>> gt_instances.head_size = np.random.random((1, 1)) * 10 + >>> pred_instances = InstanceData() + >>> pred_instances.keypoints = keypoints + >>> data_sample = { + ... 'gt_instances': gt_instances.to_dict(), + ... 'pred_instances': pred_instances.to_dict(), + ... } + >>> data_samples = [data_sample] + >>> data_batch = [{'inputs': None}] + >>> mpii_pck_metric = MpiiPCKAccuracy(thr=0.3, norm_item='head') + ... UserWarning: The prefix is not set in metric class MpiiPCKAccuracy. + >>> mpii_pck_metric.process(data_batch, data_samples) + >>> mpii_pck_metric.evaluate(1) + 10/26 17:43:39 - mmengine - INFO - Evaluating MpiiPCKAccuracy (normalized by ``"head_size"``)... # noqa + {'Head PCK': 100.0, 'Shoulder PCK': 100.0, 'Elbow PCK': 100.0, + Wrist PCK': 100.0, 'Hip PCK': 100.0, 'Knee PCK': 100.0, + 'Ankle PCK': 100.0, 'PCK': 100.0, 'PCK@0.1': 100.0} + """ + + def __init__(self, + thr: float = 0.5, + norm_item: Union[str, Sequence[str]] = 'head', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__( + thr=thr, + norm_item=norm_item, + collect_device=collect_device, + prefix=prefix) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + If `'head'` in `self.norm_item`, the returned results are the pck + accuracy normalized by `head_size`, which have the following keys: + - 'Head PCK': The PCK of head + - 'Shoulder PCK': The PCK of shoulder + - 'Elbow PCK': The PCK of elbow + - 'Wrist PCK': The PCK of wrist + - 'Hip PCK': The PCK of hip + - 'Knee PCK': The PCK of knee + - 'Ankle PCK': The PCK of ankle + - 'PCK': The mean PCK over all keypoints + - 'PCK@0.1': The mean PCK at threshold 0.1 + """ + logger: MMLogger = MMLogger.get_current_instance() + + # pred_coords: [N, K, D] + pred_coords = np.concatenate( + [result['pred_coords'] for result in results]) + # gt_coords: [N, K, D] + gt_coords = np.concatenate([result['gt_coords'] for result in results]) + # mask: [N, K] + mask = np.concatenate([result['mask'] for result in results]) + + # MPII uses matlab format, gt index is 1-based, + # convert 0-based index to 1-based index + pred_coords = pred_coords + 1.0 + + metrics = {} + if 'head' in self.norm_item: + norm_size_head = np.concatenate( + [result['head_size'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__} ' + f'(normalized by ``"head_size"``)...') + + pck_p, _, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask, + self.thr, norm_size_head) + + jnt_count = np.sum(mask, axis=0) + PCKh = 100. * pck_p + + rng = np.arange(0, 0.5 + 0.01, 0.01) + pckAll = np.zeros((len(rng), 16), dtype=np.float32) + + for r, threshold in enumerate(rng): + _pck, _, _ = keypoint_pck_accuracy(pred_coords, gt_coords, + mask, threshold, + norm_size_head) + pckAll[r, :] = 100. * _pck + + PCKh = np.ma.array(PCKh, mask=False) + PCKh.mask[6:8] = True + + jnt_count = np.ma.array(jnt_count, mask=False) + jnt_count.mask[6:8] = True + jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64) + + # dataset_joints_idx: + # head 9 + # lsho 13 rsho 12 + # lelb 14 relb 11 + # lwri 15 rwri 10 + # lhip 3 rhip 2 + # lkne 4 rkne 1 + # lank 5 rank 0 + stats = { + 'Head PCK': PCKh[9], + 'Shoulder PCK': 0.5 * (PCKh[13] + PCKh[12]), + 'Elbow PCK': 0.5 * (PCKh[14] + PCKh[11]), + 'Wrist PCK': 0.5 * (PCKh[15] + PCKh[10]), + 'Hip PCK': 0.5 * (PCKh[3] + PCKh[2]), + 'Knee PCK': 0.5 * (PCKh[4] + PCKh[1]), + 'Ankle PCK': 0.5 * (PCKh[5] + PCKh[0]), + 'PCK': np.sum(PCKh * jnt_ratio), + 'PCK@0.1': np.sum(pckAll[10, :] * jnt_ratio) + } + + for stats_name, stat in stats.items(): + metrics[stats_name] = stat + + return metrics + + +@METRICS.register_module() +class JhmdbPCKAccuracy(PCKAccuracy): + """PCK accuracy evaluation metric for Jhmdb dataset. + + Calculate the pose accuracy of Percentage of Correct Keypoints (PCK) for + each individual keypoint and the averaged accuracy across all keypoints. + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the person bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + Note: + - length of dataset: N + - num_keypoints: K + - number of keypoint dimensions: D (typically D = 2) + + Args: + thr(float): Threshold of PCK calculation. Default: 0.05. + norm_item (str | Sequence[str]): The item used for normalization. + Valid items include 'bbox', 'head', 'torso', which correspond + to 'PCK', 'PCKh' and 'tPCK' respectively. Default: ``'bbox'``. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be ``'cpu'`` or + ``'gpu'``. Default: ``'cpu'``. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, ``self.default_prefix`` + will be used instead. Default: ``None``. + + Examples: + + >>> from mmpose.evaluation.metrics import JhmdbPCKAccuracy + >>> import numpy as np + >>> from mmengine.structures import InstanceData + >>> num_keypoints = 15 + >>> keypoints = np.random.random((1, num_keypoints, 2)) * 10 + >>> gt_instances = InstanceData() + >>> gt_instances.keypoints = keypoints + >>> gt_instances.keypoints_visible = np.ones( + ... (1, num_keypoints, 1)).astype(bool) + >>> gt_instances.bboxes = np.random.random((1, 4)) * 20 + >>> gt_instances.head_size = np.random.random((1, 1)) * 10 + >>> pred_instances = InstanceData() + >>> pred_instances.keypoints = keypoints + >>> data_sample = { + ... 'gt_instances': gt_instances.to_dict(), + ... 'pred_instances': pred_instances.to_dict(), + ... } + >>> data_samples = [data_sample] + >>> data_batch = [{'inputs': None}] + >>> jhmdb_pck_metric = JhmdbPCKAccuracy(thr=0.2, norm_item=['bbox', 'torso']) + ... UserWarning: The prefix is not set in metric class JhmdbPCKAccuracy. + >>> jhmdb_pck_metric.process(data_batch, data_samples) + >>> jhmdb_pck_metric.evaluate(1) + 10/26 17:48:09 - mmengine - INFO - Evaluating JhmdbPCKAccuracy (normalized by ``"bbox_size"``)... # noqa + 10/26 17:48:09 - mmengine - INFO - Evaluating JhmdbPCKAccuracy (normalized by ``"torso_size"``)... # noqa + {'Head PCK': 1.0, 'Sho PCK': 1.0, 'Elb PCK': 1.0, 'Wri PCK': 1.0, + 'Hip PCK': 1.0, 'Knee PCK': 1.0, 'Ank PCK': 1.0, 'PCK': 1.0, + 'Head tPCK': 1.0, 'Sho tPCK': 1.0, 'Elb tPCK': 1.0, 'Wri tPCK': 1.0, + 'Hip tPCK': 1.0, 'Knee tPCK': 1.0, 'Ank tPCK': 1.0, 'tPCK': 1.0} + """ + + def __init__(self, + thr: float = 0.05, + norm_item: Union[str, Sequence[str]] = 'bbox', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__( + thr=thr, + norm_item=norm_item, + collect_device=collect_device, + prefix=prefix) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + If `'bbox'` in `self.norm_item`, the returned results are the pck + accuracy normalized by `bbox_size`, which have the following keys: + - 'Head PCK': The PCK of head + - 'Sho PCK': The PCK of shoulder + - 'Elb PCK': The PCK of elbow + - 'Wri PCK': The PCK of wrist + - 'Hip PCK': The PCK of hip + - 'Knee PCK': The PCK of knee + - 'Ank PCK': The PCK of ankle + - 'PCK': The mean PCK over all keypoints + If `'torso'` in `self.norm_item`, the returned results are the pck + accuracy normalized by `torso_size`, which have the following keys: + - 'Head tPCK': The PCK of head + - 'Sho tPCK': The PCK of shoulder + - 'Elb tPCK': The PCK of elbow + - 'Wri tPCK': The PCK of wrist + - 'Hip tPCK': The PCK of hip + - 'Knee tPCK': The PCK of knee + - 'Ank tPCK': The PCK of ankle + - 'tPCK': The mean PCK over all keypoints + """ + logger: MMLogger = MMLogger.get_current_instance() + + # pred_coords: [N, K, D] + pred_coords = np.concatenate( + [result['pred_coords'] for result in results]) + # gt_coords: [N, K, D] + gt_coords = np.concatenate([result['gt_coords'] for result in results]) + # mask: [N, K] + mask = np.concatenate([result['mask'] for result in results]) + + metrics = dict() + if 'bbox' in self.norm_item: + norm_size_bbox = np.concatenate( + [result['bbox_size'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__} ' + f'(normalized by ``"bbox_size"``)...') + + pck_p, pck, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask, + self.thr, norm_size_bbox) + stats = { + 'Head PCK': pck_p[2], + 'Sho PCK': 0.5 * pck_p[3] + 0.5 * pck_p[4], + 'Elb PCK': 0.5 * pck_p[7] + 0.5 * pck_p[8], + 'Wri PCK': 0.5 * pck_p[11] + 0.5 * pck_p[12], + 'Hip PCK': 0.5 * pck_p[5] + 0.5 * pck_p[6], + 'Knee PCK': 0.5 * pck_p[9] + 0.5 * pck_p[10], + 'Ank PCK': 0.5 * pck_p[13] + 0.5 * pck_p[14], + 'PCK': pck + } + + for stats_name, stat in stats.items(): + metrics[stats_name] = stat + + if 'torso' in self.norm_item: + norm_size_torso = np.concatenate( + [result['torso_size'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__} ' + f'(normalized by ``"torso_size"``)...') + + pck_p, pck, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask, + self.thr, norm_size_torso) + + stats = { + 'Head tPCK': pck_p[2], + 'Sho tPCK': 0.5 * pck_p[3] + 0.5 * pck_p[4], + 'Elb tPCK': 0.5 * pck_p[7] + 0.5 * pck_p[8], + 'Wri tPCK': 0.5 * pck_p[11] + 0.5 * pck_p[12], + 'Hip tPCK': 0.5 * pck_p[5] + 0.5 * pck_p[6], + 'Knee tPCK': 0.5 * pck_p[9] + 0.5 * pck_p[10], + 'Ank tPCK': 0.5 * pck_p[13] + 0.5 * pck_p[14], + 'tPCK': pck + } + + for stats_name, stat in stats.items(): + metrics[stats_name] = stat + + return metrics + + +@METRICS.register_module() +class AUC(BaseMetric): + """AUC evaluation metric. + + Calculate the Area Under Curve (AUC) of keypoint PCK accuracy. + + By altering the threshold percentage in the calculation of PCK accuracy, + AUC can be generated to further evaluate the pose estimation algorithms. + + Note: + - length of dataset: N + - num_keypoints: K + - number of keypoint dimensions: D (typically D = 2) + + Args: + norm_factor (float): AUC normalization factor, Default: 30 (pixels). + num_thrs (int): number of thresholds to calculate auc. Default: 20. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be ``'cpu'`` or + ``'gpu'``. Default: ``'cpu'``. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, ``self.default_prefix`` + will be used instead. Default: ``None``. + """ + + def __init__(self, + norm_factor: float = 30, + num_thrs: int = 20, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.norm_factor = norm_factor + self.num_thrs = num_thrs + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data + from the dataloader. + data_sample (Sequence[dict]): A batch of outputs from + the model. + """ + for data_sample in data_samples: + # predicted keypoints coordinates, [1, K, D] + pred_coords = data_sample['pred_instances']['keypoints'] + # ground truth data_info + gt = data_sample['gt_instances'] + # ground truth keypoints coordinates, [1, K, D] + gt_coords = gt['keypoints'] + # ground truth keypoints_visible, [1, K, 1] + mask = gt['keypoints_visible'].astype(bool).reshape(1, -1) + + result = { + 'pred_coords': pred_coords, + 'gt_coords': gt_coords, + 'mask': mask, + } + + self.results.append(result) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # pred_coords: [N, K, D] + pred_coords = np.concatenate( + [result['pred_coords'] for result in results]) + # gt_coords: [N, K, D] + gt_coords = np.concatenate([result['gt_coords'] for result in results]) + # mask: [N, K] + mask = np.concatenate([result['mask'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__}...') + + auc = keypoint_auc(pred_coords, gt_coords, mask, self.norm_factor, + self.num_thrs) + + metrics = dict() + metrics['AUC'] = auc + + return metrics + + +@METRICS.register_module() +class EPE(BaseMetric): + """EPE evaluation metric. + + Calculate the end-point error (EPE) of keypoints. + + Note: + - length of dataset: N + - num_keypoints: K + - number of keypoint dimensions: D (typically D = 2) + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be ``'cpu'`` or + ``'gpu'``. Default: ``'cpu'``. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, ``self.default_prefix`` + will be used instead. Default: ``None``. + """ + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data + from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from + the model. + """ + for data_sample in data_samples: + # predicted keypoints coordinates, [1, K, D] + pred_coords = data_sample['pred_instances']['keypoints'] + # ground truth data_info + gt = data_sample['gt_instances'] + # ground truth keypoints coordinates, [1, K, D] + gt_coords = gt['keypoints'] + # ground truth keypoints_visible, [1, K, 1] + mask = gt['keypoints_visible'].astype(bool).reshape(1, -1) + + result = { + 'pred_coords': pred_coords, + 'gt_coords': gt_coords, + 'mask': mask, + } + + self.results.append(result) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # pred_coords: [N, K, D] + pred_coords = np.concatenate( + [result['pred_coords'] for result in results]) + # gt_coords: [N, K, D] + gt_coords = np.concatenate([result['gt_coords'] for result in results]) + # mask: [N, K] + mask = np.concatenate([result['mask'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__}...') + + epe = keypoint_epe(pred_coords, gt_coords, mask) + + metrics = dict() + metrics['EPE'] = epe + + return metrics + + +@METRICS.register_module() +class NME(BaseMetric): + """NME evaluation metric. + + Calculate the normalized mean error (NME) of keypoints. + + Note: + - length of dataset: N + - num_keypoints: K + - number of keypoint dimensions: D (typically D = 2) + + Args: + norm_mode (str): The normalization mode. There are two valid modes: + `'use_norm_item'` and `'keypoint_distance'`. + When set as `'use_norm_item'`, should specify the argument + `norm_item`, which represents the item in the datainfo that + will be used as the normalization factor. + When set as `'keypoint_distance'`, should specify the argument + `keypoint_indices` that are used to calculate the keypoint + distance as the normalization factor. + norm_item (str, optional): The item used as the normalization factor. + For example, `'bbox_size'` in `'AFLWDataset'`. Only valid when + ``norm_mode`` is ``use_norm_item``. + Default: ``None``. + keypoint_indices (Sequence[int], optional): The keypoint indices used + to calculate the keypoint distance as the normalization factor. + Only valid when ``norm_mode`` is ``keypoint_distance``. + If set as None, will use the default ``keypoint_indices`` in + `DEFAULT_KEYPOINT_INDICES` for specific datasets, else use the + given ``keypoint_indices`` of the dataset. Default: ``None``. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be ``'cpu'`` or + ``'gpu'``. Default: ``'cpu'``. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, ``self.default_prefix`` + will be used instead. Default: ``None``. + """ + + DEFAULT_KEYPOINT_INDICES = { + # horse10: corresponding to `nose` and `eye` keypoints + 'horse10': [0, 1], + # 300w: corresponding to `right-most` and `left-most` eye keypoints + '300w': [36, 45], + # coco_wholebody_face corresponding to `right-most` and `left-most` + # eye keypoints + 'coco_wholebody_face': [36, 45], + # cofw: corresponding to `right-most` and `left-most` eye keypoints + 'cofw': [8, 9], + # wflw: corresponding to `right-most` and `left-most` eye keypoints + 'wflw': [60, 72], + } + + def __init__(self, + norm_mode: str, + norm_item: Optional[str] = None, + keypoint_indices: Optional[Sequence[int]] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + allowed_norm_modes = ['use_norm_item', 'keypoint_distance'] + if norm_mode not in allowed_norm_modes: + raise KeyError("`norm_mode` should be 'use_norm_item' or " + f"'keypoint_distance', but got {norm_mode}.") + + self.norm_mode = norm_mode + if self.norm_mode == 'use_norm_item': + if not norm_item: + raise KeyError('`norm_mode` is set to `"use_norm_item"`, ' + 'please specify the `norm_item` in the ' + 'datainfo used as the normalization factor.') + self.norm_item = norm_item + self.keypoint_indices = keypoint_indices + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data + from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from + the model. + """ + for data_sample in data_samples: + # predicted keypoints coordinates, [1, K, D] + pred_coords = data_sample['pred_instances']['keypoints'] + # ground truth data_info + gt = data_sample['gt_instances'] + # ground truth keypoints coordinates, [1, K, D] + gt_coords = gt['keypoints'] + # ground truth keypoints_visible, [1, K, 1] + mask = gt['keypoints_visible'].astype(bool).reshape(1, -1) + + result = { + 'pred_coords': pred_coords, + 'gt_coords': gt_coords, + 'mask': mask, + } + + if self.norm_item: + if self.norm_item == 'bbox_size': + assert 'bboxes' in gt, 'The ground truth data info do ' \ + 'not have the item ``bboxes`` for expected ' \ + 'normalized_item ``"bbox_size"``.' + # ground truth bboxes, [1, 4] + bbox_size = np.max(gt['bboxes'][0][2:] - + gt['bboxes'][0][:2]) + result['bbox_size'] = np.array([bbox_size]).reshape(-1, 1) + else: + assert self.norm_item in gt, f'The ground truth data ' \ + f'info do not have the expected normalized factor ' \ + f'"{self.norm_item}"' + # ground truth norm_item + result[self.norm_item] = np.array( + gt[self.norm_item]).reshape([-1, 1]) + + self.results.append(result) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # pred_coords: [N, K, D] + pred_coords = np.concatenate( + [result['pred_coords'] for result in results]) + # gt_coords: [N, K, D] + gt_coords = np.concatenate([result['gt_coords'] for result in results]) + # mask: [N, K] + mask = np.concatenate([result['mask'] for result in results]) + + logger.info(f'Evaluating {self.__class__.__name__}...') + metrics = dict() + + if self.norm_mode == 'use_norm_item': + normalize_factor_ = np.concatenate( + [result[self.norm_item] for result in results]) + # normalize_factor: [N, 2] + normalize_factor = np.tile(normalize_factor_, [1, 2]) + nme = keypoint_nme(pred_coords, gt_coords, mask, normalize_factor) + metrics['NME'] = nme + + else: + if self.keypoint_indices is None: + # use default keypoint_indices in some datasets + dataset_name = self.dataset_meta['dataset_name'] + if dataset_name not in self.DEFAULT_KEYPOINT_INDICES: + raise KeyError( + '`norm_mode` is set to `keypoint_distance`, and the ' + 'keypoint_indices is set to None, can not find the ' + 'keypoint_indices in `DEFAULT_KEYPOINT_INDICES`, ' + 'please specify `keypoint_indices` appropriately.') + self.keypoint_indices = self.DEFAULT_KEYPOINT_INDICES[ + dataset_name] + else: + assert len(self.keypoint_indices) == 2, 'The keypoint '\ + 'indices used for normalization should be a pair.' + keypoint_id2name = self.dataset_meta['keypoint_id2name'] + dataset_name = self.dataset_meta['dataset_name'] + for idx in self.keypoint_indices: + assert idx in keypoint_id2name, f'The {dataset_name} '\ + f'dataset does not contain the required '\ + f'{idx}-th keypoint.' + # normalize_factor: [N, 2] + normalize_factor = self._get_normalize_factor(gt_coords=gt_coords) + nme = keypoint_nme(pred_coords, gt_coords, mask, normalize_factor) + metrics['NME'] = nme + + return metrics + + def _get_normalize_factor(self, gt_coords: np.ndarray) -> np.ndarray: + """Get the normalize factor. generally inter-ocular distance measured + as the Euclidean distance between the outer corners of the eyes is + used. + + Args: + gt_coords (np.ndarray[N, K, 2]): Groundtruth keypoint coordinates. + + Returns: + np.ndarray[N, 2]: normalized factor + """ + idx1, idx2 = self.keypoint_indices + + interocular = np.linalg.norm( + gt_coords[:, idx1, :] - gt_coords[:, idx2, :], + axis=1, + keepdims=True) + + return np.tile(interocular, [1, 2]) diff --git a/mmpose/evaluation/metrics/keypoint_partition_metric.py b/mmpose/evaluation/metrics/keypoint_partition_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..fb30eca0d57f68e94cba93deec1f63bd333468aa --- /dev/null +++ b/mmpose/evaluation/metrics/keypoint_partition_metric.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from collections import OrderedDict +from copy import deepcopy +from typing import Sequence + +import numpy as np +from mmengine.evaluator import BaseMetric + +from mmpose.registry import METRICS + + +@METRICS.register_module() +class KeypointPartitionMetric(BaseMetric): + """Wrapper metric for evaluating pose metric on user-defined body parts. + + Sometimes one may be interested in the performance of a pose model on + certain body parts rather than on all the keypoints. For example, + ``CocoWholeBodyMetric`` evaluates coco metric on body, foot, face, + lefthand and righthand. However, ``CocoWholeBodyMetric`` cannot be + applied to arbitrary custom datasets. This wrapper metric solves this + problem. + + Supported metrics: + ``CocoMetric`` Note 1: all keypoint ground truth should be stored in + `keypoints` not other data fields. Note 2: `ann_file` is not + supported, it will be ignored. Note 3: `score_mode` other than + 'bbox' may produce results different from the + ``CocoWholebodyMetric``. Note 4: `nms_mode` other than 'none' may + produce results different from the ``CocoWholebodyMetric``. + ``PCKAccuracy`` Note 1: data fields required by ``PCKAccuracy`` should + be provided, such as bbox, head_size, etc. Note 2: In terms of + 'torso', since it is specifically designed for ``JhmdbDataset``, it is + not recommended to use it for other datasets. + ``AUC`` supported without limitations. + ``EPE`` supported without limitations. + ``NME`` only `norm_mode` = 'use_norm_item' is supported, + 'keypoint_distance' is incompatible with ``KeypointPartitionMetric``. + + Incompatible metrics: + The following metrics are dataset specific metrics: + ``CocoWholeBodyMetric`` + ``MpiiPCKAccuracy`` + ``JhmdbPCKAccuracy`` + ``PoseTrack18Metric`` + Keypoint partitioning is included in these metrics. + + Args: + metric (dict): arguments to instantiate a metric, please refer to the + arguments required by the metric of your choice. + partitions (dict): definition of body partitions. For example, if we + have 10 keypoints in total, the first 7 keypoints belong to body + and the last 3 keypoints belong to foot, this field can be like + this: + dict( + body=[0, 1, 2, 3, 4, 5, 6], + foot=[7, 8, 9], + all=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + where the numbers are the indices of keypoints and they can be + discontinuous. + """ + + def __init__( + self, + metric: dict, + partitions: dict, + ) -> None: + super().__init__() + # check metric type + supported_metric_types = [ + 'CocoMetric', 'PCKAccuracy', 'AUC', 'EPE', 'NME' + ] + if metric['type'] not in supported_metric_types: + raise ValueError( + 'Metrics supported by KeypointPartitionMetric are CocoMetric, ' + 'PCKAccuracy, AUC, EPE and NME, ' + f"but got {metric['type']}") + + # check CocoMetric arguments + if metric['type'] == 'CocoMetric': + if 'ann_file' in metric: + warnings.warn( + 'KeypointPartitionMetric does not support the ann_file ' + 'argument of CocoMetric, this argument will be ignored.') + metric['ann_file'] = None + score_mode = metric.get('score_mode', 'bbox_keypoint') + if score_mode != 'bbox': + warnings.warn( + 'When using KeypointPartitionMetric with CocoMetric, ' + "if score_mode is not 'bbox', pose scores will be " + "calculated part by part rather than by 'wholebody'. " + 'Therefore, this may produce results different from the ' + 'CocoWholebodyMetric.') + nms_mode = metric.get('nms_mode', 'oks_nms') + if nms_mode != 'none': + warnings.warn( + 'When using KeypointPartitionMetric with CocoMetric, ' + 'oks_nms and soft_oks_nms will be calculated part by part ' + "rather than by 'wholebody'. Therefore, this may produce " + 'results different from the CocoWholebodyMetric.') + + # check PCKAccuracy arguments + if metric['type'] == 'PCKAccuracy': + norm_item = metric.get('norm_item', 'bbox') + if norm_item == 'torso' or 'torso' in norm_item: + warnings.warn( + 'norm_item torso is used in JhmdbDataset, it may not be ' + 'compatible with other datasets, use at your own risk.') + + # check NME arguments + if metric['type'] == 'NME': + assert 'norm_mode' in metric, \ + 'Missing norm_mode required by the NME metric.' + if metric['norm_mode'] != 'use_norm_item': + raise ValueError( + "NME norm_mode 'keypoint_distance' is incompatible with " + 'KeypointPartitionMetric.') + + # check partitions + assert len(partitions) > 0, 'There should be at least one partition.' + for partition_name, partition in partitions.items(): + assert isinstance(partition, Sequence), \ + 'Each partition should be a sequence.' + assert len(partition) > 0, \ + 'Each partition should have at least one element.' + self.partitions = partitions + + # instantiate metrics for each partition + self.metrics = {} + for partition_name in partitions.keys(): + _metric = deepcopy(metric) + if 'outfile_prefix' in _metric: + _metric['outfile_prefix'] = _metric[ + 'outfile_prefix'] + '.' + partition_name + self.metrics[partition_name] = METRICS.build(_metric) + + @BaseMetric.dataset_meta.setter + def dataset_meta(self, dataset_meta: dict) -> None: + """Set the dataset meta info to the metric.""" + self._dataset_meta = dataset_meta + # sigmas required by coco metric have to be split as well + for partition_name, keypoint_ids in self.partitions.items(): + _dataset_meta = deepcopy(dataset_meta) + _dataset_meta['num_keypoints'] = len(keypoint_ids) + _dataset_meta['sigmas'] = _dataset_meta['sigmas'][keypoint_ids] + self.metrics[partition_name].dataset_meta = _dataset_meta + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Split data samples by partitions, then call metric.process part by + part.""" + parted_data_samples = { + partition_name: [] + for partition_name in self.partitions.keys() + } + for data_sample in data_samples: + for partition_name, keypoint_ids in self.partitions.items(): + _data_sample = deepcopy(data_sample) + if 'keypoint_scores' in _data_sample['pred_instances']: + _data_sample['pred_instances'][ + 'keypoint_scores'] = _data_sample['pred_instances'][ + 'keypoint_scores'][:, keypoint_ids] + _data_sample['pred_instances']['keypoints'] = _data_sample[ + 'pred_instances']['keypoints'][:, keypoint_ids] + _data_sample['gt_instances']['keypoints'] = _data_sample[ + 'gt_instances']['keypoints'][:, keypoint_ids] + _data_sample['gt_instances'][ + 'keypoints_visible'] = _data_sample['gt_instances'][ + 'keypoints_visible'][:, keypoint_ids] + + # for coco metric + if 'raw_ann_info' in _data_sample: + raw_ann_info = _data_sample['raw_ann_info'] + anns = raw_ann_info if isinstance( + raw_ann_info, list) else [raw_ann_info] + for ann in anns: + if 'keypoints' in ann: + keypoints = np.array(ann['keypoints']).reshape( + -1, 3) + keypoints = keypoints[keypoint_ids] + num_keypoints = np.sum(keypoints[:, 2] > 0) + ann['keypoints'] = keypoints.flatten().tolist() + ann['num_keypoints'] = num_keypoints + + parted_data_samples[partition_name].append(_data_sample) + + for partition_name, metric in self.metrics.items(): + metric.process(data_batch, parted_data_samples[partition_name]) + + def compute_metrics(self, results: list) -> dict: + pass + + def evaluate(self, size: int) -> dict: + """Run evaluation for each partition.""" + eval_results = OrderedDict() + for partition_name, metric in self.metrics.items(): + _eval_results = metric.evaluate(size) + for key in list(_eval_results.keys()): + new_key = partition_name + '/' + key + _eval_results[new_key] = _eval_results.pop(key) + eval_results.update(_eval_results) + return eval_results diff --git a/mmpose/evaluation/metrics/posetrack18_metric.py b/mmpose/evaluation/metrics/posetrack18_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..86f801455a62467aaf45722210a6018c95b0bdd4 --- /dev/null +++ b/mmpose/evaluation/metrics/posetrack18_metric.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from typing import Dict, List, Optional + +import numpy as np +from mmengine.fileio import dump, load +from mmengine.logging import MMLogger + +from mmpose.registry import METRICS +from .coco_metric import CocoMetric + +try: + from poseval import eval_helpers + from poseval.evaluateAP import evaluateAP + has_poseval = True +except (ImportError, ModuleNotFoundError): + has_poseval = False + + +@METRICS.register_module() +class PoseTrack18Metric(CocoMetric): + """PoseTrack18 evaluation metric. + + Evaluate AP, and mAP for keypoint detection tasks. + Support PoseTrack18 (video) dataset. Please refer to + ``__ + for more details. + + Args: + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None + score_mode (str): The mode to score the prediction results which + should be one of the following options: + + - ``'bbox'``: Take the score of bbox as the score of the + prediction results. + - ``'bbox_keypoint'``: Use keypoint score to rescore the + prediction results. + + Defaults to ``'bbox_keypoint'` + keypoint_score_thr (float): The threshold of keypoint score. The + keypoints with score lower than it will not be included to + rescore the prediction results. Valid only when ``score_mode`` is + ``bbox_keypoint``. Defaults to ``0.2`` + nms_mode (str): The mode to perform Non-Maximum Suppression (NMS), + which should be one of the following options: + + - ``'oks_nms'``: Use Object Keypoint Similarity (OKS) to + perform NMS. + - ``'soft_oks_nms'``: Use Object Keypoint Similarity (OKS) + to perform soft NMS. + - ``'none'``: Do not perform NMS. Typically for bottomup mode + output. + + Defaults to ``'oks_nms'` + nms_thr (float): The Object Keypoint Similarity (OKS) threshold + used in NMS when ``nms_mode`` is ``'oks_nms'`` or + ``'soft_oks_nms'``. Will retain the prediction results with OKS + lower than ``nms_thr``. Defaults to ``0.9`` + format_only (bool): Whether only format the output results without + doing quantitative evaluation. This is designed for the need of + test submission when the ground truth annotations are absent. If + set to ``True``, ``outfile_prefix`` should specify the path to + store the output results. Defaults to ``False`` + outfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., ``'a/b/prefix'``. + If not specified, a temp file will be created. Defaults to ``None`` + **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric` + """ + default_prefix: Optional[str] = 'posetrack18' + + def __init__(self, + ann_file: Optional[str] = None, + score_mode: str = 'bbox_keypoint', + keypoint_score_thr: float = 0.2, + nms_mode: str = 'oks_nms', + nms_thr: float = 0.9, + format_only: bool = False, + outfile_prefix: Optional[str] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + # raise an error to avoid long time running without getting results + if not has_poseval: + raise ImportError('Please install ``poseval`` package for ' + 'evaluation on PoseTrack dataset ' + '(see `requirements/optional.txt`)') + super().__init__( + ann_file=ann_file, + score_mode=score_mode, + keypoint_score_thr=keypoint_score_thr, + nms_mode=nms_mode, + nms_thr=nms_thr, + format_only=format_only, + outfile_prefix=outfile_prefix, + collect_device=collect_device, + prefix=prefix) + + def results2json(self, keypoints: Dict[int, list], + outfile_prefix: str) -> str: + """Dump the keypoint detection results into a json file. + + Args: + keypoints (Dict[int, list]): Keypoint detection results + of the dataset. + outfile_prefix (str): The filename prefix of the json files. + If the prefix is "somepath/xxx", the json files will be named + "somepath/xxx.keypoints.json". + + Returns: + str: The json file name of keypoint results. + """ + categories = [] + + cat = {} + cat['supercategory'] = 'person' + cat['id'] = 1 + cat['name'] = 'person' + cat['keypoints'] = [ + 'nose', 'head_bottom', 'head_top', 'left_ear', 'right_ear', + 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', + 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee', + 'right_knee', 'left_ankle', 'right_ankle' + ] + cat['skeleton'] = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], + [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], + [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], + [4, 6], [5, 7]] + categories.append(cat) + + # path of directory for official gt files + gt_folder = osp.join( + osp.dirname(self.ann_file), + osp.splitext(self.ann_file.split('_')[-1])[0]) + # the json file for each video sequence + json_files = [ + pos for pos in os.listdir(gt_folder) if pos.endswith('.json') + ] + + for json_file in json_files: + gt = load(osp.join(gt_folder, json_file)) + annotations = [] + images = [] + + for image in gt['images']: + img = {} + img['id'] = image['id'] + img['file_name'] = image['file_name'] + images.append(img) + + img_kpts = keypoints[img['id']] + + for track_id, img_kpt in enumerate(img_kpts): + ann = {} + ann['image_id'] = img_kpt['img_id'] + ann['keypoints'] = np.array( + img_kpt['keypoints']).reshape(-1).tolist() + ann['scores'] = np.array(ann['keypoints']).reshape( + [-1, 3])[:, 2].tolist() + ann['score'] = float(img_kpt['score']) + ann['track_id'] = track_id + annotations.append(ann) + + pred_file = osp.join(osp.dirname(outfile_prefix), json_file) + info = {} + info['images'] = images + info['categories'] = categories + info['annotations'] = annotations + + dump(info, pred_file, sort_keys=True, indent=4) + + def _do_python_keypoint_eval(self, outfile_prefix: str) -> List[tuple]: + """Do keypoint evaluation using `poseval` package. + + Args: + outfile_prefix (str): The filename prefix of the json files. + If the prefix is "somepath/xxx", the json files will be named + "somepath/xxx.keypoints.json". + + Returns: + list: a list of tuples. Each tuple contains the evaluation stats + name and corresponding stats value. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # path of directory for official gt files + # 'xxx/posetrack18_train.json' -> 'xxx/train/' + gt_folder = osp.join( + osp.dirname(self.ann_file), + osp.splitext(self.ann_file.split('_')[-1])[0]) + pred_folder = osp.dirname(outfile_prefix) + + argv = ['', gt_folder + '/', pred_folder + '/'] + + logger.info('Loading data') + gtFramesAll, prFramesAll = eval_helpers.load_data_dir(argv) + + logger.info(f'# gt frames : {len(gtFramesAll)}') + logger.info(f'# pred frames: {len(prFramesAll)}') + + # evaluate per-frame multi-person pose estimation (AP) + # compute AP + logger.info('Evaluation of per-frame multi-person pose estimation') + apAll, _, _ = evaluateAP(gtFramesAll, prFramesAll, None, False, False) + + # print AP + logger.info('Average Precision (AP) metric:') + eval_helpers.printTable(apAll) + + stats = eval_helpers.getCum(apAll) + + stats_names = [ + 'Head AP', 'Shou AP', 'Elb AP', 'Wri AP', 'Hip AP', 'Knee AP', + 'Ankl AP', 'AP' + ] + + info_str = list(zip(stats_names, stats)) + + return info_str diff --git a/mmpose/models/__init__.py b/mmpose/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e236f9928601328c6cd42817d842b034c4b9b13 --- /dev/null +++ b/mmpose/models/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backbones import * # noqa +from .builder import (BACKBONES, HEADS, LOSSES, NECKS, build_backbone, + build_head, build_loss, build_neck, build_pose_estimator, + build_posenet) +from .data_preprocessors import * # noqa +from .heads import * # noqa +from .losses import * # noqa +from .necks import * # noqa +from .pose_estimators import * # noqa + +__all__ = [ + 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'build_backbone', 'build_head', + 'build_loss', 'build_posenet', 'build_neck', 'build_pose_estimator' +] diff --git a/mmpose/models/__pycache__/__init__.cpython-38.pyc b/mmpose/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbd830f81f73c0252b056888761e237100e6b4b9 Binary files /dev/null and b/mmpose/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/__pycache__/builder.cpython-38.pyc b/mmpose/models/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb117b14e01105b2cfa30261f6a07bbddf37255d Binary files /dev/null and b/mmpose/models/__pycache__/builder.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__init__.py b/mmpose/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb2498560a34afca0f3218eb9fb3d9a5dce04f33 --- /dev/null +++ b/mmpose/models/backbones/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .alexnet import AlexNet +from .cpm import CPM +from .hourglass import HourglassNet +from .hourglass_ae import HourglassAENet +from .hrformer import HRFormer +from .hrnet import HRNet +from .litehrnet import LiteHRNet +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .mspn import MSPN +from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2 +from .regnet import RegNet +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1d +from .resnext import ResNeXt +from .rsn import RSN +from .scnet import SCNet +from .seresnet import SEResNet +from .seresnext import SEResNeXt +from .shufflenet_v1 import ShuffleNetV1 +from .shufflenet_v2 import ShuffleNetV2 +from .swin import SwinTransformer +from .tcn import TCN +from .v2v_net import V2VNet +from .vgg import VGG +from .vipnas_mbv3 import ViPNAS_MobileNetV3 +from .vipnas_resnet import ViPNAS_ResNet + +__all__ = [ + 'AlexNet', 'HourglassNet', 'HourglassAENet', 'HRNet', 'MobileNetV2', + 'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet', + 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN', + 'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3', + 'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer', + 'PyramidVisionTransformerV2', 'SwinTransformer' +] diff --git a/mmpose/models/backbones/__pycache__/__init__.cpython-38.pyc b/mmpose/models/backbones/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4e4cfedbe68164a189ed5944aac624262bdce8e Binary files /dev/null and b/mmpose/models/backbones/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/alexnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/alexnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02c57e557c0134d15251f5ac12b55f77d68b4df3 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/alexnet.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/base_backbone.cpython-38.pyc b/mmpose/models/backbones/__pycache__/base_backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..094c6ee895026f56fbcc2e7bcbcf6321d96abe0e Binary files /dev/null and b/mmpose/models/backbones/__pycache__/base_backbone.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/cpm.cpython-38.pyc b/mmpose/models/backbones/__pycache__/cpm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb8562d4df605418d4a928315f1b028dfa52e7df Binary files /dev/null and b/mmpose/models/backbones/__pycache__/cpm.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/hourglass.cpython-38.pyc b/mmpose/models/backbones/__pycache__/hourglass.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80f4d651cc22b66f19c213535969aca9d43aea9d Binary files /dev/null and b/mmpose/models/backbones/__pycache__/hourglass.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/hourglass_ae.cpython-38.pyc b/mmpose/models/backbones/__pycache__/hourglass_ae.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..930671a1781da7dfa946dfd89092c2a6a614e18c Binary files /dev/null and b/mmpose/models/backbones/__pycache__/hourglass_ae.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/hrformer.cpython-38.pyc b/mmpose/models/backbones/__pycache__/hrformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0454fdd8b97ac049f6af1f4f489494e254e8a74 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/hrformer.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/hrnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/hrnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d6cf2922bcf62743f8301108b7336cb7ed76296 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/hrnet.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/litehrnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/litehrnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b221bf8e66196f8da93b0ce99eb01ea1aecccf67 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/litehrnet.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc b/mmpose/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a19238d7c874b35916c03c0828c79d8ec8c6e54 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc b/mmpose/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6efce920167c3aec9a4770240ac633f2dd6de743 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/mspn.cpython-38.pyc b/mmpose/models/backbones/__pycache__/mspn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57485f1119da1c1d2d7117b2c20f5a3180d8bd97 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/mspn.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/pvt.cpython-38.pyc b/mmpose/models/backbones/__pycache__/pvt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44f378804eacefbf068b182bcfb0ffdb57f77b1c Binary files /dev/null and b/mmpose/models/backbones/__pycache__/pvt.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/regnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/regnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a8b3a935e4a4fb583623f0d661a3302a44e8f0c Binary files /dev/null and b/mmpose/models/backbones/__pycache__/regnet.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/resnest.cpython-38.pyc b/mmpose/models/backbones/__pycache__/resnest.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7744c3dc0b1cbfffa7eb70f6ac41624fad87c33 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/resnest.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/resnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53c46e3a4d63845f435050e10ea36083680bcff0 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/resnet.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/resnext.cpython-38.pyc b/mmpose/models/backbones/__pycache__/resnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d12fe070429012319a668ed14ea6ce8e8bb10f24 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/resnext.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/rsn.cpython-38.pyc b/mmpose/models/backbones/__pycache__/rsn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e89ef05de9f755ede5f823d3656dc21e3f433db1 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/rsn.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/scnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/scnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd163eea9dea0ba2e629a5dca6a5e35685ff3b52 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/scnet.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/seresnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/seresnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..225ff7dc153b2ceced76869a66954e4149b69e62 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/seresnet.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/seresnext.cpython-38.pyc b/mmpose/models/backbones/__pycache__/seresnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4d263daf00c53f9bc9bfae8c1717c6da5f0090b Binary files /dev/null and b/mmpose/models/backbones/__pycache__/seresnext.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc b/mmpose/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ddc089fa4190b165f2ba0429a13a453bd438621 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc b/mmpose/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6818af553035028af47fe2142ea04a947215a05a Binary files /dev/null and b/mmpose/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/swin.cpython-38.pyc b/mmpose/models/backbones/__pycache__/swin.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b29f78b9cceecb95a90606481d4712c18e2d1b5 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/swin.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/tcn.cpython-38.pyc b/mmpose/models/backbones/__pycache__/tcn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58da95f1730f1f849bb082661ebdb905c58894cc Binary files /dev/null and b/mmpose/models/backbones/__pycache__/tcn.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/v2v_net.cpython-38.pyc b/mmpose/models/backbones/__pycache__/v2v_net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60dd9065bab22797b21fe6dd1e8e9b488f52adea Binary files /dev/null and b/mmpose/models/backbones/__pycache__/v2v_net.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/vgg.cpython-38.pyc b/mmpose/models/backbones/__pycache__/vgg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bd7809dd77d58bf6af1a9223f4d5e0ca80ed9c1 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/vgg.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/vipnas_mbv3.cpython-38.pyc b/mmpose/models/backbones/__pycache__/vipnas_mbv3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bf2f571b1c9317f5830ac3872621686222c6bd9 Binary files /dev/null and b/mmpose/models/backbones/__pycache__/vipnas_mbv3.cpython-38.pyc differ diff --git a/mmpose/models/backbones/__pycache__/vipnas_resnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/vipnas_resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c6851030a84692aa8c5fb44b00472a4ca499e4c Binary files /dev/null and b/mmpose/models/backbones/__pycache__/vipnas_resnet.cpython-38.pyc differ diff --git a/mmpose/models/backbones/alexnet.py b/mmpose/models/backbones/alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2262658f4718a079b2effc276282be4d39fbe6ad --- /dev/null +++ b/mmpose/models/backbones/alexnet.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class AlexNet(BaseBackbone): + """`AlexNet `__ backbone. + + The input for AlexNet is a 224x224 RGB image. + + Args: + num_classes (int): number of classes for classification. + The default value is -1, which uses the backbone as + a feature extractor without the top classifier. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, num_classes=-1, init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + + return (x, ) diff --git a/mmpose/models/backbones/base_backbone.py b/mmpose/models/backbones/base_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..6094b4e831f992b052e4db206022f489a7f729b3 --- /dev/null +++ b/mmpose/models/backbones/base_backbone.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +from mmengine.model import BaseModule + + +class BaseBackbone(BaseModule, metaclass=ABCMeta): + """Base backbone. + + This class defines the basic functions of a backbone. Any backbone that + inherits this class should at least define its own `forward` function. + """ + + @abstractmethod + def forward(self, x): + """Forward function. + + Args: + x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of + torch.Tensor, containing input data for forward computation. + """ + + def train(self, mode=True): + """Set module status before forward computation. + + Args: + mode (bool): Whether it is train_mode or test_mode + """ + super(BaseBackbone, self).train(mode) diff --git a/mmpose/models/backbones/cpm.py b/mmpose/models/backbones/cpm.py new file mode 100644 index 0000000000000000000000000000000000000000..256769c43a4d7b9d0cdd40fb6de19a90727012e8 --- /dev/null +++ b/mmpose/models/backbones/cpm.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +class CpmBlock(BaseModule): + """CpmBlock for Convolutional Pose Machine. + + Args: + in_channels (int): Input channels of this block. + channels (list): Output channels of each conv module. + kernels (list): Kernel sizes of each conv module. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + channels=(128, 128, 128), + kernels=(11, 11, 11), + norm_cfg=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + assert len(channels) == len(kernels) + layers = [] + for i in range(len(channels)): + if i == 0: + input_channels = in_channels + else: + input_channels = channels[i - 1] + layers.append( + ConvModule( + input_channels, + channels[i], + kernels[i], + padding=(kernels[i] - 1) // 2, + norm_cfg=norm_cfg)) + self.model = nn.Sequential(*layers) + + def forward(self, x): + """Model forward function.""" + out = self.model(x) + return out + + +@MODELS.register_module() +class CPM(BaseBackbone): + """CPM backbone. + + Convolutional Pose Machines. + More details can be found in the `paper + `__ . + + Args: + in_channels (int): The input channels of the CPM. + out_channels (int): The output channels of the CPM. + feat_channels (int): Feature channel of each CPM stage. + middle_channels (int): Feature channel of conv after the middle stage. + num_stages (int): Number of stages. + norm_cfg (dict): Dictionary to construct and config norm layer. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import CPM + >>> import torch + >>> self = CPM(3, 17) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 368, 368) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... print(tuple(level_output.shape)) + (1, 17, 46, 46) + (1, 17, 46, 46) + (1, 17, 46, 46) + (1, 17, 46, 46) + (1, 17, 46, 46) + (1, 17, 46, 46) + """ + + def __init__( + self, + in_channels, + out_channels, + feat_channels=128, + middle_channels=32, + num_stages=6, + norm_cfg=dict(type='BN', requires_grad=True), + init_cfg=[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ], + ): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + + assert in_channels == 3 + + self.num_stages = num_stages + assert self.num_stages >= 1 + + self.stem = nn.Sequential( + ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 32, 5, padding=2, norm_cfg=norm_cfg), + ConvModule(32, 512, 9, padding=4, norm_cfg=norm_cfg), + ConvModule(512, 512, 1, padding=0, norm_cfg=norm_cfg), + ConvModule(512, out_channels, 1, padding=0, act_cfg=None)) + + self.middle = nn.Sequential( + ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + self.cpm_stages = nn.ModuleList([ + CpmBlock( + middle_channels + out_channels, + channels=[feat_channels, feat_channels, feat_channels], + kernels=[11, 11, 11], + norm_cfg=norm_cfg) for _ in range(num_stages - 1) + ]) + + self.middle_conv = nn.ModuleList([ + nn.Sequential( + ConvModule( + 128, middle_channels, 5, padding=2, norm_cfg=norm_cfg)) + for _ in range(num_stages - 1) + ]) + + self.out_convs = nn.ModuleList([ + nn.Sequential( + ConvModule( + feat_channels, + feat_channels, + 1, + padding=0, + norm_cfg=norm_cfg), + ConvModule(feat_channels, out_channels, 1, act_cfg=None)) + for _ in range(num_stages - 1) + ]) + + def forward(self, x): + """Model forward function.""" + stage1_out = self.stem(x) + middle_out = self.middle(x) + out_feats = [] + + out_feats.append(stage1_out) + + for ind in range(self.num_stages - 1): + single_stage = self.cpm_stages[ind] + out_conv = self.out_convs[ind] + + inp_feat = torch.cat( + [out_feats[-1], self.middle_conv[ind](middle_out)], 1) + cpm_feat = single_stage(inp_feat) + out_feat = out_conv(cpm_feat) + out_feats.append(out_feat) + + return out_feats diff --git a/mmpose/models/backbones/hourglass.py b/mmpose/models/backbones/hourglass.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc8d6d328da5b63094015351cc10084cda46da0 --- /dev/null +++ b/mmpose/models/backbones/hourglass.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone +from .resnet import BasicBlock, ResLayer + + +class HourglassModule(BaseModule): + """Hourglass Module for HourglassNet backbone. + + Generate module recursively and use BasicBlock as the base unit. + + Args: + depth (int): Depth of current HourglassModule. + stage_channels (list[int]): Feature channels of sub-modules in current + and follow-up HourglassModule. + stage_blocks (list[int]): Number of sub-modules stacked in current and + follow-up HourglassModule. + norm_cfg (dict): Dictionary to construct and config norm layer. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + depth, + stage_channels, + stage_blocks, + norm_cfg=dict(type='BN', requires_grad=True), + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + + self.depth = depth + + cur_block = stage_blocks[0] + next_block = stage_blocks[1] + + cur_channel = stage_channels[0] + next_channel = stage_channels[1] + + self.up1 = ResLayer( + BasicBlock, cur_block, cur_channel, cur_channel, norm_cfg=norm_cfg) + + self.low1 = ResLayer( + BasicBlock, + cur_block, + cur_channel, + next_channel, + stride=2, + norm_cfg=norm_cfg) + + if self.depth > 1: + self.low2 = HourglassModule(depth - 1, stage_channels[1:], + stage_blocks[1:]) + else: + self.low2 = ResLayer( + BasicBlock, + next_block, + next_channel, + next_channel, + norm_cfg=norm_cfg) + + self.low3 = ResLayer( + BasicBlock, + cur_block, + next_channel, + cur_channel, + norm_cfg=norm_cfg, + downsample_first=False) + + self.up2 = nn.Upsample(scale_factor=2) + + def forward(self, x): + """Model forward function.""" + up1 = self.up1(x) + low1 = self.low1(x) + low2 = self.low2(low1) + low3 = self.low3(low2) + up2 = self.up2(low3) + return up1 + up2 + + +@MODELS.register_module() +class HourglassNet(BaseBackbone): + """HourglassNet backbone. + + Stacked Hourglass Networks for Human Pose Estimation. + More details can be found in the `paper + `__ . + + Args: + downsample_times (int): Downsample times in a HourglassModule. + num_stacks (int): Number of HourglassModule modules stacked, + 1 for Hourglass-52, 2 for Hourglass-104. + stage_channels (list[int]): Feature channel of each sub-module in a + HourglassModule. + stage_blocks (list[int]): Number of sub-modules stacked in a + HourglassModule. + feat_channel (int): Feature channel of conv after a HourglassModule. + norm_cfg (dict): Dictionary to construct and config norm layer. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import HourglassNet + >>> import torch + >>> self = HourglassNet() + >>> self.eval() + >>> inputs = torch.rand(1, 3, 511, 511) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... print(tuple(level_output.shape)) + (1, 256, 128, 128) + (1, 256, 128, 128) + """ + + def __init__( + self, + downsample_times=5, + num_stacks=2, + stage_channels=(256, 256, 384, 384, 384, 512), + stage_blocks=(2, 2, 2, 2, 2, 4), + feat_channel=256, + norm_cfg=dict(type='BN', requires_grad=True), + init_cfg=[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ], + ): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + + self.num_stacks = num_stacks + assert self.num_stacks >= 1 + assert len(stage_channels) == len(stage_blocks) + assert len(stage_channels) > downsample_times + + cur_channel = stage_channels[0] + + self.stem = nn.Sequential( + ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg), + ResLayer(BasicBlock, 1, 128, 256, stride=2, norm_cfg=norm_cfg)) + + self.hourglass_modules = nn.ModuleList([ + HourglassModule(downsample_times, stage_channels, stage_blocks) + for _ in range(num_stacks) + ]) + + self.inters = ResLayer( + BasicBlock, + num_stacks - 1, + cur_channel, + cur_channel, + norm_cfg=norm_cfg) + + self.conv1x1s = nn.ModuleList([ + ConvModule( + cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None) + for _ in range(num_stacks - 1) + ]) + + self.out_convs = nn.ModuleList([ + ConvModule( + cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg) + for _ in range(num_stacks) + ]) + + self.remap_convs = nn.ModuleList([ + ConvModule( + feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None) + for _ in range(num_stacks - 1) + ]) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Model forward function.""" + inter_feat = self.stem(x) + out_feats = [] + + for ind in range(self.num_stacks): + single_hourglass = self.hourglass_modules[ind] + out_conv = self.out_convs[ind] + + hourglass_feat = single_hourglass(inter_feat) + out_feat = out_conv(hourglass_feat) + out_feats.append(out_feat) + + if ind < self.num_stacks - 1: + inter_feat = self.conv1x1s[ind]( + inter_feat) + self.remap_convs[ind]( + out_feat) + inter_feat = self.inters[ind](self.relu(inter_feat)) + + return out_feats diff --git a/mmpose/models/backbones/hourglass_ae.py b/mmpose/models/backbones/hourglass_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..93e62dd4067c3489de00c5cd1f7875489725de2e --- /dev/null +++ b/mmpose/models/backbones/hourglass_ae.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +from mmcv.cnn import ConvModule, MaxPool2d +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +class HourglassAEModule(BaseModule): + """Modified Hourglass Module for HourglassNet_AE backbone. + + Generate module recursively and use BasicBlock as the base unit. + + Args: + depth (int): Depth of current HourglassModule. + stage_channels (list[int]): Feature channels of sub-modules in current + and follow-up HourglassModule. + norm_cfg (dict): Dictionary to construct and config norm layer. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + depth, + stage_channels, + norm_cfg=dict(type='BN', requires_grad=True), + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + + self.depth = depth + + cur_channel = stage_channels[0] + next_channel = stage_channels[1] + + self.up1 = ConvModule( + cur_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg) + + self.pool1 = MaxPool2d(2, 2) + + self.low1 = ConvModule( + cur_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg) + + if self.depth > 1: + self.low2 = HourglassAEModule(depth - 1, stage_channels[1:]) + else: + self.low2 = ConvModule( + next_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg) + + self.low3 = ConvModule( + next_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg) + + self.up2 = nn.UpsamplingNearest2d(scale_factor=2) + + def forward(self, x): + """Model forward function.""" + up1 = self.up1(x) + pool1 = self.pool1(x) + low1 = self.low1(pool1) + low2 = self.low2(low1) + low3 = self.low3(low2) + up2 = self.up2(low3) + return up1 + up2 + + +@MODELS.register_module() +class HourglassAENet(BaseBackbone): + """Hourglass-AE Network proposed by Newell et al. + + Associative Embedding: End-to-End Learning for Joint + Detection and Grouping. + + More details can be found in the `paper + `__ . + + Args: + downsample_times (int): Downsample times in a HourglassModule. + num_stacks (int): Number of HourglassModule modules stacked, + 1 for Hourglass-52, 2 for Hourglass-104. + stage_channels (list[int]): Feature channel of each sub-module in a + HourglassModule. + stage_blocks (list[int]): Number of sub-modules stacked in a + HourglassModule. + feat_channels (int): Feature channel of conv after a HourglassModule. + norm_cfg (dict): Dictionary to construct and config norm layer. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import HourglassAENet + >>> import torch + >>> self = HourglassAENet() + >>> self.eval() + >>> inputs = torch.rand(1, 3, 512, 512) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... print(tuple(level_output.shape)) + (1, 34, 128, 128) + """ + + def __init__( + self, + downsample_times=4, + num_stacks=1, + out_channels=34, + stage_channels=(256, 384, 512, 640, 768), + feat_channels=256, + norm_cfg=dict(type='BN', requires_grad=True), + init_cfg=[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ], + ): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + + self.num_stacks = num_stacks + assert self.num_stacks >= 1 + assert len(stage_channels) > downsample_times + + cur_channels = stage_channels[0] + + self.stem = nn.Sequential( + ConvModule(3, 64, 7, padding=3, stride=2, norm_cfg=norm_cfg), + ConvModule(64, 128, 3, padding=1, norm_cfg=norm_cfg), + MaxPool2d(2, 2), + ConvModule(128, 128, 3, padding=1, norm_cfg=norm_cfg), + ConvModule(128, feat_channels, 3, padding=1, norm_cfg=norm_cfg), + ) + + self.hourglass_modules = nn.ModuleList([ + nn.Sequential( + HourglassAEModule( + downsample_times, stage_channels, norm_cfg=norm_cfg), + ConvModule( + feat_channels, + feat_channels, + 3, + padding=1, + norm_cfg=norm_cfg), + ConvModule( + feat_channels, + feat_channels, + 3, + padding=1, + norm_cfg=norm_cfg)) for _ in range(num_stacks) + ]) + + self.out_convs = nn.ModuleList([ + ConvModule( + cur_channels, + out_channels, + 1, + padding=0, + norm_cfg=None, + act_cfg=None) for _ in range(num_stacks) + ]) + + self.remap_out_convs = nn.ModuleList([ + ConvModule( + out_channels, + feat_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None) for _ in range(num_stacks - 1) + ]) + + self.remap_feature_convs = nn.ModuleList([ + ConvModule( + feat_channels, + feat_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None) for _ in range(num_stacks - 1) + ]) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Model forward function.""" + inter_feat = self.stem(x) + out_feats = [] + + for ind in range(self.num_stacks): + single_hourglass = self.hourglass_modules[ind] + out_conv = self.out_convs[ind] + + hourglass_feat = single_hourglass(inter_feat) + out_feat = out_conv(hourglass_feat) + out_feats.append(out_feat) + + if ind < self.num_stacks - 1: + inter_feat = inter_feat + self.remap_out_convs[ind]( + out_feat) + self.remap_feature_convs[ind]( + hourglass_feat) + + return out_feats diff --git a/mmpose/models/backbones/hrformer.py b/mmpose/models/backbones/hrformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4712cfdfb5c0454e00d7e2e09b244424f2c80f5a --- /dev/null +++ b/mmpose/models/backbones/hrformer.py @@ -0,0 +1,763 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import math + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import build_dropout +from mmengine.model import BaseModule, trunc_normal_init +from torch.nn.functional import pad + +from mmpose.registry import MODELS +from .hrnet import Bottleneck, HRModule, HRNet + + +def nlc_to_nchw(x, hw_shape): + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before conversion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after conversion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + return x.transpose(1, 2).reshape(B, C, H, W) + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before conversion. + + Returns: + Tensor: The output tensor of shape [N, L, C] after conversion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() + + +def build_drop_path(drop_path_rate): + """Build drop path layer.""" + return build_dropout(dict(type='DropPath', drop_prob=drop_path_rate)) + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + with_rpe (bool, optional): If True, use relative position bias. + Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + with_rpe=True, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + self.with_rpe = with_rpe + if self.with_rpe: + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_init(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (B*num_windows, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.with_rpe: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class LocalWindowSelfAttention(BaseModule): + r""" Local-window Self Attention (LSA) module with relative position bias. + + This module is the short-range self-attention module in the + Interlaced Sparse Self-Attention `_. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int] | int): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + with_rpe (bool, optional): If True, use relative position bias. + Default: True. + with_pad_mask (bool, optional): If True, mask out the padded tokens in + the attention process. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + with_rpe=True, + with_pad_mask=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if isinstance(window_size, int): + window_size = (window_size, window_size) + self.window_size = window_size + self.with_pad_mask = with_pad_mask + self.attn = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + with_rpe=with_rpe, + init_cfg=init_cfg) + + def forward(self, x, H, W, **kwargs): + """Forward function.""" + B, N, C = x.shape + x = x.view(B, H, W, C) + Wh, Ww = self.window_size + + # center-pad the feature on H and W axes + pad_h = math.ceil(H / Wh) * Wh - H + pad_w = math.ceil(W / Ww) * Ww - W + x = pad(x, (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2)) + + # permute + x = x.view(B, math.ceil(H / Wh), Wh, math.ceil(W / Ww), Ww, C) + x = x.permute(0, 1, 3, 2, 4, 5) + x = x.reshape(-1, Wh * Ww, C) # (B*num_window, Wh*Ww, C) + + # attention + if self.with_pad_mask and pad_h > 0 and pad_w > 0: + pad_mask = x.new_zeros(1, H, W, 1) + pad_mask = pad( + pad_mask, [ + 0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ], + value=-float('inf')) + pad_mask = pad_mask.view(1, math.ceil(H / Wh), Wh, + math.ceil(W / Ww), Ww, 1) + pad_mask = pad_mask.permute(1, 3, 0, 2, 4, 5) + pad_mask = pad_mask.reshape(-1, Wh * Ww) + pad_mask = pad_mask[:, None, :].expand([-1, Wh * Ww, -1]) + out = self.attn(x, pad_mask, **kwargs) + else: + out = self.attn(x, **kwargs) + + # reverse permutation + out = out.reshape(B, math.ceil(H / Wh), math.ceil(W / Ww), Wh, Ww, C) + out = out.permute(0, 1, 3, 2, 4, 5) + out = out.reshape(B, H + pad_h, W + pad_w, C) + + # de-pad + out = out[:, pad_h // 2:H + pad_h // 2, pad_w // 2:W + pad_w // 2] + return out.reshape(B, N, C) + + +class CrossFFN(BaseModule): + r"""FFN with Depthwise Conv of HRFormer. + + Args: + in_features (int): The feature dimension. + hidden_features (int, optional): The hidden dimension of FFNs. + Defaults: The same as in_features. + act_cfg (dict, optional): Config of activation layer. + Default: dict(type='GELU'). + dw_act_cfg (dict, optional): Config of activation layer appended + right after DW Conv. Default: dict(type='GELU'). + norm_cfg (dict, optional): Config of norm layer. + Default: dict(type='SyncBN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + dw_act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1) + self.act1 = build_activation_layer(act_cfg) + self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1] + self.dw3x3 = nn.Conv2d( + hidden_features, + hidden_features, + kernel_size=3, + stride=1, + groups=hidden_features, + padding=1) + self.act2 = build_activation_layer(dw_act_cfg) + self.norm2 = build_norm_layer(norm_cfg, hidden_features)[1] + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1) + self.act3 = build_activation_layer(act_cfg) + self.norm3 = build_norm_layer(norm_cfg, out_features)[1] + + # put the modules togather + self.layers = [ + self.fc1, self.norm1, self.act1, self.dw3x3, self.norm2, self.act2, + self.fc2, self.norm3, self.act3 + ] + + def forward(self, x, H, W): + """Forward function.""" + x = nlc_to_nchw(x, (H, W)) + for layer in self.layers: + x = layer(x) + x = nchw_to_nlc(x) + return x + + +class HRFormerBlock(BaseModule): + """High-Resolution Block for HRFormer. + + Args: + in_features (int): The input dimension. + out_features (int): The output dimension. + num_heads (int): The number of head within each LSA. + window_size (int, optional): The window size for the LSA. + Default: 7 + mlp_ratio (int, optional): The expansion ration of FFN. + Default: 4 + act_cfg (dict, optional): Config of activation layer. + Default: dict(type='GELU'). + norm_cfg (dict, optional): Config of norm layer. + Default: dict(type='SyncBN'). + transformer_norm_cfg (dict, optional): Config of transformer norm + layer. Default: dict(type='LN', eps=1e-6). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + expansion = 1 + + def __init__(self, + in_features, + out_features, + num_heads, + window_size=7, + mlp_ratio=4.0, + drop_path=0.0, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN'), + transformer_norm_cfg=dict(type='LN', eps=1e-6), + init_cfg=None, + **kwargs): + super(HRFormerBlock, self).__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.norm1 = build_norm_layer(transformer_norm_cfg, in_features)[1] + self.attn = LocalWindowSelfAttention( + in_features, + num_heads=num_heads, + window_size=window_size, + init_cfg=None, + **kwargs) + + self.norm2 = build_norm_layer(transformer_norm_cfg, out_features)[1] + self.ffn = CrossFFN( + in_features=in_features, + hidden_features=int(in_features * mlp_ratio), + out_features=out_features, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dw_act_cfg=act_cfg, + init_cfg=None) + + self.drop_path = build_drop_path( + drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + """Forward function.""" + B, C, H, W = x.size() + # Attention + x = x.view(B, C, -1).permute(0, 2, 1) + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + # FFN + x = x + self.drop_path(self.ffn(self.norm2(x), H, W)) + x = x.permute(0, 2, 1).view(B, C, H, W) + return x + + def extra_repr(self): + """(Optional) Set the extra information about this module.""" + return 'num_heads={}, window_size={}, mlp_ratio={}'.format( + self.num_heads, self.window_size, self.mlp_ratio) + + +class HRFomerModule(HRModule): + """High-Resolution Module for HRFormer. + + Args: + num_branches (int): The number of branches in the HRFormerModule. + block (nn.Module): The building block of HRFormer. + The block should be the HRFormerBlock. + num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + num_inchannels (tuple): The number of input channels in each branch. + The length must be equal to num_branches. + num_channels (tuple): The number of channels in each branch. + The length must be equal to num_branches. + num_heads (tuple): The number of heads within the LSAs. + num_window_sizes (tuple): The window size for the LSAs. + num_mlp_ratios (tuple): The expansion ratio for the FFNs. + drop_path (int, optional): The drop path rate of HRFomer. + Default: 0.0 + multiscale_output (bool, optional): Whether to output multi-level + features produced by multiple branches. If False, only the first + level feature will be output. Default: True. + conv_cfg (dict, optional): Config of the conv layers. + Default: None. + norm_cfg (dict, optional): Config of the norm layers appended + right after conv. Default: dict(type='SyncBN', requires_grad=True) + transformer_norm_cfg (dict, optional): Config of the norm layers. + Default: dict(type='LN', eps=1e-6) + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False + upsample_cfg(dict, optional): The config of upsample layers in fuse + layers. Default: dict(mode='bilinear', align_corners=False) + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + multiscale_output=True, + drop_paths=0.0, + with_rpe=True, + with_pad_mask=False, + conv_cfg=None, + norm_cfg=dict(type='SyncBN', requires_grad=True), + transformer_norm_cfg=dict(type='LN', eps=1e-6), + with_cp=False, + upsample_cfg=dict(mode='bilinear', align_corners=False), + **kwargs): + + self.transformer_norm_cfg = transformer_norm_cfg + self.drop_paths = drop_paths + self.num_heads = num_heads + self.num_window_sizes = num_window_sizes + self.num_mlp_ratios = num_mlp_ratios + self.with_rpe = with_rpe + self.with_pad_mask = with_pad_mask + + super().__init__(num_branches, block, num_blocks, num_inchannels, + num_channels, multiscale_output, with_cp, conv_cfg, + norm_cfg, upsample_cfg, **kwargs) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + """Build one branch.""" + # HRFormerBlock does not support down sample layer yet. + assert stride == 1 and self.in_channels[branch_index] == num_channels[ + branch_index] + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + num_heads=self.num_heads[branch_index], + window_size=self.num_window_sizes[branch_index], + mlp_ratio=self.num_mlp_ratios[branch_index], + drop_path=self.drop_paths[0], + norm_cfg=self.norm_cfg, + transformer_norm_cfg=self.transformer_norm_cfg, + init_cfg=None, + with_rpe=self.with_rpe, + with_pad_mask=self.with_pad_mask)) + + self.in_channels[ + branch_index] = self.in_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + num_heads=self.num_heads[branch_index], + window_size=self.num_window_sizes[branch_index], + mlp_ratio=self.num_mlp_ratios[branch_index], + drop_path=self.drop_paths[i], + norm_cfg=self.norm_cfg, + transformer_norm_cfg=self.transformer_norm_cfg, + init_cfg=None, + with_rpe=self.with_rpe, + with_pad_mask=self.with_pad_mask)) + return nn.Sequential(*layers) + + def _make_fuse_layers(self): + """Build fuse layers.""" + if self.num_branches == 1: + return None + num_branches = self.num_branches + num_inchannels = self.in_channels + fuse_layers = [] + for i in range(num_branches if self.multiscale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_inchannels[j], + num_inchannels[i], + kernel_size=1, + stride=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_inchannels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), + mode=self.upsample_cfg['mode'], + align_corners=self. + upsample_cfg['align_corners']))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + with_out_act = False + else: + num_outchannels_conv3x3 = num_inchannels[j] + with_out_act = True + sub_modules = [ + build_conv_layer( + self.conv_cfg, + num_inchannels[j], + num_inchannels[j], + kernel_size=3, + stride=2, + padding=1, + groups=num_inchannels[j], + bias=False, + ), + build_norm_layer(self.norm_cfg, + num_inchannels[j])[1], + build_conv_layer( + self.conv_cfg, + num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=1, + stride=1, + bias=False, + ), + build_norm_layer(self.norm_cfg, + num_outchannels_conv3x3)[1] + ] + if with_out_act: + sub_modules.append(nn.ReLU(False)) + conv3x3s.append(nn.Sequential(*sub_modules)) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + """Return the number of input channels.""" + return self.in_channels + + +@MODELS.register_module() +class HRFormer(HRNet): + """HRFormer backbone. + + This backbone is the implementation of `HRFormer: High-Resolution + Transformer for Dense Prediction `_. + + Args: + extra (dict): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of block. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of channels in each branch. + The length must be equal to num_branches. + in_channels (int): Number of input image channels. Normally 3. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Config of norm layer. + Use `SyncBN` by default. + transformer_norm_cfg (dict): Config of transformer norm layer. + Use `LN` by default. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import HRFormer + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(2, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='HRFORMER', + >>> window_sizes=(7, 7), + >>> num_heads=(1, 2), + >>> mlp_ratios=(4, 4), + >>> num_blocks=(2, 2), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='HRFORMER', + >>> window_sizes=(7, 7, 7), + >>> num_heads=(1, 2, 4), + >>> mlp_ratios=(4, 4, 4), + >>> num_blocks=(2, 2, 2), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=2, + >>> num_branches=4, + >>> block='HRFORMER', + >>> window_sizes=(7, 7, 7, 7), + >>> num_heads=(1, 2, 4, 8), + >>> mlp_ratios=(4, 4, 4, 4), + >>> num_blocks=(2, 2, 2, 2), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRFormer(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BOTTLENECK': Bottleneck, 'HRFORMERBLOCK': HRFormerBlock} + + def __init__( + self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + transformer_norm_cfg=dict(type='LN', eps=1e-6), + norm_eval=False, + with_cp=False, + zero_init_residual=False, + frozen_stages=-1, + init_cfg=[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ], + ): + + # stochastic depth + depths = [ + extra[stage]['num_blocks'][0] * extra[stage]['num_modules'] + for stage in ['stage2', 'stage3', 'stage4'] + ] + depth_s2, depth_s3, _ = depths + drop_path_rate = extra['drop_path_rate'] + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + extra['stage2']['drop_path_rates'] = dpr[0:depth_s2] + extra['stage3']['drop_path_rates'] = dpr[depth_s2:depth_s2 + depth_s3] + extra['stage4']['drop_path_rates'] = dpr[depth_s2 + depth_s3:] + + # HRFormer use bilinear upsample as default + upsample_cfg = extra.get('upsample', { + 'mode': 'bilinear', + 'align_corners': False + }) + extra['upsample'] = upsample_cfg + self.transformer_norm_cfg = transformer_norm_cfg + self.with_rpe = extra.get('with_rpe', True) + self.with_pad_mask = extra.get('with_pad_mask', False) + + super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval, + with_cp, zero_init_residual, frozen_stages, init_cfg) + + def _make_stage(self, + layer_config, + num_inchannels, + multiscale_output=True): + """Make each stage.""" + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + num_heads = layer_config['num_heads'] + num_window_sizes = layer_config['window_sizes'] + num_mlp_ratios = layer_config['mlp_ratios'] + drop_path_rates = layer_config['drop_path_rates'] + + modules = [] + for i in range(num_modules): + # multiscale_output is only used at the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + modules.append( + HRFomerModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + reset_multiscale_output, + drop_paths=drop_path_rates[num_blocks[0] * + i:num_blocks[0] * (i + 1)], + with_rpe=self.with_rpe, + with_pad_mask=self.with_pad_mask, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + transformer_norm_cfg=self.transformer_norm_cfg, + with_cp=self.with_cp, + upsample_cfg=self.upsample_cfg)) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels diff --git a/mmpose/models/backbones/hrnet.py b/mmpose/models/backbones/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..381b22d60ec886ecb6d8c52fc9e7ccab52c05e99 --- /dev/null +++ b/mmpose/models/backbones/hrnet.py @@ -0,0 +1,610 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, constant_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone +from .resnet import BasicBlock, Bottleneck, get_expansion + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. + """ + + def __init__(self, + num_branches, + blocks, + num_blocks, + in_channels, + num_channels, + multiscale_output=False, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + upsample_cfg=dict(mode='nearest', align_corners=None), + init_cfg=None): + + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.upsample_cfg = upsample_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + @staticmethod + def _check_branches(num_branches, num_blocks, in_channels, num_channels): + """Check input to avoid ValueError.""" + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_BLOCKS({len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_CHANNELS({len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_INCHANNELS({len(in_channels)})' + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + """Make one branch.""" + downsample = None + if stride != 1 or \ + self.in_channels[branch_index] != \ + num_channels[branch_index] * get_expansion(block): + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.in_channels[branch_index], + num_channels[branch_index] * get_expansion(block), + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer( + self.norm_cfg, + num_channels[branch_index] * get_expansion(block))[1]) + + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index] * get_expansion(block), + stride=stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + self.in_channels[branch_index] = \ + num_channels[branch_index] * get_expansion(block) + for _ in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index] * get_expansion(block), + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + """Make branches.""" + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + """Make fuse layer.""" + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), + mode=self.upsample_cfg['mode'], + align_corners=self. + upsample_cfg['align_corners']))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@MODELS.register_module() +class HRNet(BaseBackbone): + """HRNet backbone. + + `High-Resolution Representations for Labeling Pixels and Regions + `__ + + Args: + extra (dict): detailed configuration for each stage of HRNet. + in_channels (int): Number of input image channels. Default: 3. + conv_cfg (dict): dictionary to construct and config conv layer. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import HRNet + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + def __init__( + self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + norm_eval=False, + with_cp=False, + zero_init_residual=False, + frozen_stages=-1, + init_cfg=[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ], + ): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.init_cfg = init_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.zero_init_residual = zero_init_residual + self.frozen_stages = frozen_stages + + # stem net + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + 64, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + self.upsample_cfg = self.extra.get('upsample', { + 'mode': 'nearest', + 'align_corners': None + }) + + # stage 1 + self.stage1_cfg = self.extra['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'][0] + + block = self.blocks_dict[block_type] + stage1_out_channels = num_channels * get_expansion(block) + self.layer1 = self._make_layer(block, 64, stage1_out_channels, + num_blocks) + + # stage 2 + self.stage2_cfg = self.extra['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = self.stage2_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [ + channel * get_expansion(block) for channel in num_channels + ] + self.transition1 = self._make_transition_layer([stage1_out_channels], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # stage 3 + self.stage3_cfg = self.extra['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = self.stage3_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [ + channel * get_expansion(block) for channel in num_channels + ] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + # stage 4 + self.stage4_cfg = self.extra['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = self.stage4_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [ + channel * get_expansion(block) for channel in num_channels + ] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, + num_channels, + multiscale_output=self.stage4_cfg.get('multiscale_output', False)) + + self._freeze_stages() + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + """Make transition layer.""" + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, in_channels, out_channels, blocks, stride=1): + """Make layer.""" + downsample = None + if stride != 1 or in_channels != out_channels: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1]) + + layers = [] + layers.append( + block( + in_channels, + out_channels, + stride=stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + for _ in range(1, blocks): + layers.append( + block( + out_channels, + out_channels, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + """Make stage.""" + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + upsample_cfg=self.upsample_cfg)) + + in_channels = hr_modules[-1].in_channels + + return nn.Sequential(*hr_modules), in_channels + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.norm1.eval() + self.norm2.eval() + + for m in [self.conv1, self.norm1, self.conv2, self.norm2]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + if i == 1: + m = getattr(self, 'layer1') + else: + m = getattr(self, f'stage{i}') + + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if i < 4: + m = getattr(self, f'transition{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone.""" + super(HRNet, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress zero_init_residual if use pretrained model. + return + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + + def forward(self, x): + """Forward function.""" + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['num_branches']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['num_branches']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['num_branches']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + return tuple(y_list) + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/backbones/litehrnet.py b/mmpose/models/backbones/litehrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad5f63014553129a02ca3dc4bfda4c181fcd6a6 --- /dev/null +++ b/mmpose/models/backbones/litehrnet.py @@ -0,0 +1,999 @@ +# ------------------------------------------------------------------------------ +# Adapted from https://github.com/HRNet/Lite-HRNet +# Original licence: Apache License 2.0. +# ------------------------------------------------------------------------------ + +import mmengine +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, + build_conv_layer, build_norm_layer) +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone +from .utils import channel_shuffle + + +class SpatialWeighting(BaseModule): + """Spatial weighting module. + + Args: + channels (int): The channels of the module. + ratio (int): channel reduction ratio. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: None. + act_cfg (dict): Config dict for activation layer. + Default: (dict(type='ReLU'), dict(type='Sigmoid')). + The last ConvModule uses Sigmoid by default. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + norm_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert mmengine.is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=int(channels / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(channels / ratio), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out + + +class CrossResolutionWeighting(BaseModule): + """Cross-resolution channel weighting module. + + Args: + channels (int): The channels of the module. + ratio (int): channel reduction ratio. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: None. + act_cfg (dict): Config dict for activation layer. + Default: (dict(type='ReLU'), dict(type='Sigmoid')). + The last ConvModule uses Sigmoid by default. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + norm_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert mmengine.is_tuple_of(act_cfg, dict) + self.channels = channels + total_channel = sum(channels) + self.conv1 = ConvModule( + in_channels=total_channel, + out_channels=int(total_channel / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(total_channel / ratio), + out_channels=total_channel, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + mini_size = x[-1].size()[-2:] + out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]] + out = torch.cat(out, dim=1) + out = self.conv1(out) + out = self.conv2(out) + out = torch.split(out, self.channels, dim=1) + out = [ + s * F.interpolate(a, size=s.size()[-2:], mode='nearest') + for s, a in zip(x, out) + ] + return out + + +class ConditionalChannelWeighting(BaseModule): + """Conditional channel weighting block. + + Args: + in_channels (int): The input channels of the block. + stride (int): Stride of the 3x3 convolution layer. + reduce_ratio (int): channel reduction ratio. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + stride, + reduce_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.with_cp = with_cp + self.stride = stride + assert stride in [1, 2] + + branch_channels = [channel // 2 for channel in in_channels] + + self.cross_resolution_weighting = CrossResolutionWeighting( + branch_channels, + ratio=reduce_ratio, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + self.depthwise_convs = nn.ModuleList([ + ConvModule( + channel, + channel, + kernel_size=3, + stride=self.stride, + padding=1, + groups=channel, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) for channel in branch_channels + ]) + + self.spatial_weighting = nn.ModuleList([ + SpatialWeighting(channels=channel, ratio=4) + for channel in branch_channels + ]) + + def forward(self, x): + + def _inner_forward(x): + x = [s.chunk(2, dim=1) for s in x] + x1 = [s[0] for s in x] + x2 = [s[1] for s in x] + + x2 = self.cross_resolution_weighting(x2) + x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)] + x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)] + + out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)] + out = [channel_shuffle(s, 2) for s in out] + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class Stem(BaseModule): + """Stem network block. + + Args: + in_channels (int): The input channels of the block. + stem_channels (int): Output channels of the stem layer. + out_channels (int): The output channels of the block. + expand_ratio (int): adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + stem_channels, + out_channels, + expand_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=stem_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='ReLU')) + + mid_channels = int(round(stem_channels * expand_ratio)) + branch_channels = stem_channels // 2 + if stem_channels == self.out_channels: + inc_channels = self.out_channels - branch_channels + else: + inc_channels = self.out_channels - stem_channels + + self.branch1 = nn.Sequential( + ConvModule( + branch_channels, + branch_channels, + kernel_size=3, + stride=2, + padding=1, + groups=branch_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_channels, + inc_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')), + ) + + self.expand_conv = ConvModule( + branch_channels, + mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + self.depthwise_conv = ConvModule( + mid_channels, + mid_channels, + kernel_size=3, + stride=2, + padding=1, + groups=mid_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + self.linear_conv = ConvModule( + mid_channels, + branch_channels + if stem_channels == self.out_channels else stem_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + + def forward(self, x): + + def _inner_forward(x): + x = self.conv1(x) + x1, x2 = x.chunk(2, dim=1) + + x2 = self.expand_conv(x2) + x2 = self.depthwise_conv(x2) + x2 = self.linear_conv(x2) + + out = torch.cat((self.branch1(x1), x2), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class IterativeHead(BaseModule): + """Extra iterative head for feature learning. + + Args: + in_channels (int): The input channels of the block. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, in_channels, norm_cfg=dict(type='BN'), init_cfg=None): + super().__init__(init_cfg=init_cfg) + projects = [] + num_branchs = len(in_channels) + self.in_channels = in_channels[::-1] + + for i in range(num_branchs): + if i != num_branchs - 1: + projects.append( + DepthwiseSeparableConvModule( + in_channels=self.in_channels[i], + out_channels=self.in_channels[i + 1], + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + dw_act_cfg=None, + pw_act_cfg=dict(type='ReLU'))) + else: + projects.append( + DepthwiseSeparableConvModule( + in_channels=self.in_channels[i], + out_channels=self.in_channels[i], + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + dw_act_cfg=None, + pw_act_cfg=dict(type='ReLU'))) + self.projects = nn.ModuleList(projects) + + def forward(self, x): + x = x[::-1] + + y = [] + last_x = None + for i, s in enumerate(x): + if last_x is not None: + last_x = F.interpolate( + last_x, + size=s.size()[-2:], + mode='bilinear', + align_corners=True) + s = s + last_x + s = self.projects[i](s) + y.append(s) + last_x = s + + return y[::-1] + + +class ShuffleUnit(BaseModule): + """InvertedResidual block for ShuffleNetV2 backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 convolution layer. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.stride = stride + self.with_cp = with_cp + + branch_features = out_channels // 2 + if self.stride == 1: + assert in_channels == branch_features * 2, ( + f'in_channels ({in_channels}) should equal to ' + f'branch_features * 2 ({branch_features * 2}) ' + 'when stride is 1') + + if in_channels != branch_features * 2: + assert self.stride != 1, ( + f'stride ({self.stride}) should not equal 1 when ' + f'in_channels != branch_features * 2') + + if self.stride > 1: + self.branch1 = nn.Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=self.stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + + self.branch2 = nn.Sequential( + ConvModule( + in_channels if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + groups=branch_features, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + + def _inner_forward(x): + if self.stride > 1: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + else: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class LiteHRModule(BaseModule): + """High-Resolution Module for LiteHRNet. + + It contains conditional channel weighting blocks and + shuffle blocks. + + + Args: + num_branches (int): Number of branches in the module. + num_blocks (int): Number of blocks in the module. + in_channels (list(int)): Number of input image channels. + reduce_ratio (int): Channel reduction ratio. + module_type (str): 'LITE' or 'NAIVE' + multiscale_output (bool): Whether to output multi-scale features. + with_fuse (bool): Whether to use fuse layers. + conv_cfg (dict): dictionary to construct and config conv layer. + norm_cfg (dict): dictionary to construct and config norm layer. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + num_branches, + num_blocks, + in_channels, + reduce_ratio, + module_type, + multiscale_output=False, + with_fuse=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self._check_branches(num_branches, in_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.module_type = module_type + self.multiscale_output = multiscale_output + self.with_fuse = with_fuse + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + + if self.module_type.upper() == 'LITE': + self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio) + elif self.module_type.upper() == 'NAIVE': + self.layers = self._make_naive_branches(num_branches, num_blocks) + else: + raise ValueError("module_type should be either 'LITE' or 'NAIVE'.") + if self.with_fuse: + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU() + + def _check_branches(self, num_branches, in_channels): + """Check input to avoid ValueError.""" + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_INCHANNELS({len(in_channels)})' + raise ValueError(error_msg) + + def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1): + """Make channel weighting blocks.""" + layers = [] + for i in range(num_blocks): + layers.append( + ConditionalChannelWeighting( + self.in_channels, + stride=stride, + reduce_ratio=reduce_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + with_cp=self.with_cp)) + + return nn.Sequential(*layers) + + def _make_one_branch(self, branch_index, num_blocks, stride=1): + """Make one branch.""" + layers = [] + layers.append( + ShuffleUnit( + self.in_channels[branch_index], + self.in_channels[branch_index], + stride=stride, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='ReLU'), + with_cp=self.with_cp)) + for i in range(1, num_blocks): + layers.append( + ShuffleUnit( + self.in_channels[branch_index], + self.in_channels[branch_index], + stride=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='ReLU'), + with_cp=self.with_cp)) + + return nn.Sequential(*layers) + + def _make_naive_branches(self, num_branches, num_blocks): + """Make branches.""" + branches = [] + + for i in range(num_branches): + branches.append(self._make_one_branch(i, num_blocks)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + """Make fuse layer.""" + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + groups=in_channels[j], + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + groups=in_channels[j], + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.layers[0](x[0])] + + if self.module_type.upper() == 'LITE': + out = self.layers(x) + elif self.module_type.upper() == 'NAIVE': + for i in range(self.num_branches): + x[i] = self.layers[i](x[i]) + out = x + + if self.with_fuse: + out_fuse = [] + for i in range(len(self.fuse_layers)): + # `y = 0` will lead to decreased accuracy (0.5~1 mAP) + y = out[0] if i == 0 else self.fuse_layers[i][0](out[0]) + for j in range(self.num_branches): + if i == j: + y += out[j] + else: + y += self.fuse_layers[i][j](out[j]) + out_fuse.append(self.relu(y)) + out = out_fuse + if not self.multiscale_output: + out = [out[0]] + return out + + +@MODELS.register_module() +class LiteHRNet(BaseBackbone): + """Lite-HRNet backbone. + + `Lite-HRNet: A Lightweight High-Resolution Network + `_. + + Code adapted from 'https://github.com/HRNet/Lite-HRNet'. + + Args: + extra (dict): detailed configuration for each stage of HRNet. + in_channels (int): Number of input image channels. Default: 3. + conv_cfg (dict): dictionary to construct and config conv layer. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import LiteHRNet + >>> import torch + >>> extra=dict( + >>> stem=dict(stem_channels=32, out_channels=32, expand_ratio=1), + >>> num_stages=3, + >>> stages_spec=dict( + >>> num_modules=(2, 4, 2), + >>> num_branches=(2, 3, 4), + >>> num_blocks=(2, 2, 2), + >>> module_type=('LITE', 'LITE', 'LITE'), + >>> with_fuse=(True, True, True), + >>> reduce_ratios=(8, 8, 8), + >>> num_channels=( + >>> (40, 80), + >>> (40, 80, 160), + >>> (40, 80, 160, 320), + >>> )), + >>> with_head=False) + >>> self = LiteHRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 40, 8, 8) + """ + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super().__init__(init_cfg=init_cfg) + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.stem = Stem( + in_channels, + stem_channels=self.extra['stem']['stem_channels'], + out_channels=self.extra['stem']['out_channels'], + expand_ratio=self.extra['stem']['expand_ratio'], + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + + self.num_stages = self.extra['num_stages'] + self.stages_spec = self.extra['stages_spec'] + + num_channels_last = [ + self.stem.out_channels, + ] + for i in range(self.num_stages): + num_channels = self.stages_spec['num_channels'][i] + num_channels = [num_channels[i] for i in range(len(num_channels))] + setattr( + self, f'transition{i}', + self._make_transition_layer(num_channels_last, num_channels)) + + stage, num_channels_last = self._make_stage( + self.stages_spec, i, num_channels, multiscale_output=True) + setattr(self, f'stage{i}', stage) + + self.with_head = self.extra['with_head'] + if self.with_head: + self.head_layer = IterativeHead( + in_channels=num_channels_last, + norm_cfg=self.norm_cfg, + ) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + """Make transition layer.""" + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_pre_layer[i], + kernel_size=3, + stride=1, + padding=1, + groups=num_channels_pre_layer[i], + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_pre_layer[i])[1], + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU())) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=1, + groups=in_channels, + bias=False), + build_norm_layer(self.norm_cfg, in_channels)[1], + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU())) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_stage(self, + stages_spec, + stage_index, + in_channels, + multiscale_output=True): + num_modules = stages_spec['num_modules'][stage_index] + num_branches = stages_spec['num_branches'][stage_index] + num_blocks = stages_spec['num_blocks'][stage_index] + reduce_ratio = stages_spec['reduce_ratios'][stage_index] + with_fuse = stages_spec['with_fuse'][stage_index] + module_type = stages_spec['module_type'][stage_index] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + modules.append( + LiteHRModule( + num_branches, + num_blocks, + in_channels, + reduce_ratio, + module_type, + multiscale_output=reset_multiscale_output, + with_fuse=with_fuse, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + with_cp=self.with_cp)) + in_channels = modules[-1].in_channels + + return nn.Sequential(*modules), in_channels + + def forward(self, x): + """Forward function.""" + x = self.stem(x) + + y_list = [x] + for i in range(self.num_stages): + x_list = [] + transition = getattr(self, f'transition{i}') + for j in range(self.stages_spec['num_branches'][i]): + if transition[j]: + if j >= len(y_list): + x_list.append(transition[j](y_list[-1])) + else: + x_list.append(transition[j](y_list[j])) + else: + x_list.append(y_list[j]) + y_list = getattr(self, f'stage{i}')(x_list) + + x = y_list + if self.with_head: + x = self.head_layer(x) + + return (x[0], ) + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/backbones/mobilenet_v2.py b/mmpose/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..b64c0d73d41d3763018a8e46621c6ab695be6856 --- /dev/null +++ b/mmpose/models/backbones/mobilenet_v2.py @@ -0,0 +1,279 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone +from .utils import make_divisible + + +class InvertedResidual(BaseModule): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__(init_cfg=init_cfg) + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=1, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class MobileNetV2(BaseBackbone): + """MobileNetV2 backbone. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + """ + + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks, stride. + arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], + [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], + [6, 320, 1, 1]] + + def __init__(self, + widen_factor=1., + out_indices=(7, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__(init_cfg=init_cfg) + self.widen_factor = widen_factor + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 8): + raise ValueError('the item in out_indices must in ' + f'range(0, 8). But received {index}') + + if frozen_stages not in range(-1, 8): + raise ValueError('frozen_stages must be in range(-1, 8). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks, stride = layer_cfg + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + if widen_factor > 1.0: + self.out_channel = int(1280 * widen_factor) + else: + self.out_channel = 1280 + + layer = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channel, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.add_module('conv2', layer) + self.layers.append('conv2') + + def make_layer(self, out_channels, num_blocks, stride, expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. Default: 6. + """ + layers = [] + for i in range(num_blocks): + if i >= 1: + stride = 1 + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride, + expand_ratio=expand_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/backbones/mobilenet_v3.py b/mmpose/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..03ecf90dd22d42a3650a4eac00c070ec556c7912 --- /dev/null +++ b/mmpose/models/backbones/mobilenet_v3.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +from mmcv.cnn import ConvModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone +from .utils import InvertedResidual + + +@MODELS.register_module() +class MobileNetV3(BaseBackbone): + """MobileNetV3 backbone. + + Args: + arch (str): Architecture of mobilnetv3, from {small, big}. + Default: small. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (None or Sequence[int]): Output from which stages. + Default: (-1, ), which means output tensors from final stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm']) + ]`` + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 2], + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'big': [[3, 16, 16, False, 'ReLU', 1], + [3, 64, 24, False, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN'), + out_indices=(-1, ), + frozen_stages=-1, + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm']) + ]): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + assert arch in self.arch_settings + for index in out_indices: + if index not in range(-len(self.arch_settings[arch]), + len(self.arch_settings[arch])): + raise ValueError('the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch])}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch])): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch])}). ' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = 16 + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type='HSwish')) + + self.layers = self._make_layer() + self.feat_dim = self.arch_settings[arch][-1][2] + + def _make_layer(self): + layers = [] + layer_setting = self.arch_settings[self.arch] + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=1.0, divisor=2.0))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=self.in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + with_expand_conv=True, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + self.in_channels = out_channels + layer_name = f'layer{i + 1}' + self.add_module(layer_name, layer) + layers.append(layer_name) + return layers + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices or \ + i - len(self.layers) in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/backbones/mspn.py b/mmpose/models/backbones/mspn.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb636b1a3fdc0357fa7dc7c3751738914d58980 --- /dev/null +++ b/mmpose/models/backbones/mspn.py @@ -0,0 +1,541 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy as cp +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, MaxPool2d +from mmengine.model import BaseModule +from mmengine.runner import load_state_dict + +from mmpose.registry import MODELS +from mmpose.utils import get_root_logger +from .base_backbone import BaseBackbone +from .resnet import Bottleneck as _Bottleneck +from .utils import get_state_dict + + +class Bottleneck(_Bottleneck): + expansion = 4 + """Bottleneck block for MSPN. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + stride (int): stride of the block. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels, out_channels * 4, **kwargs) + + +class DownsampleModule(BaseModule): + """Downsample module for MSPN. + + Args: + block (nn.Module): Downsample block. + num_blocks (list): Number of blocks in each downsample unit. + num_units (int): Numbers of downsample units. Default: 4 + has_skip (bool): Have skip connections from prior upsample + module or not. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + in_channels (int): Number of channels of the input feature to + downsample module. Default: 64 + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + block, + num_blocks, + num_units=4, + has_skip=False, + norm_cfg=dict(type='BN'), + in_channels=64, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.has_skip = has_skip + self.in_channels = in_channels + assert len(num_blocks) == num_units + self.num_blocks = num_blocks + self.num_units = num_units + self.norm_cfg = norm_cfg + self.layer1 = self._make_layer(block, in_channels, num_blocks[0]) + for i in range(1, num_units): + module_name = f'layer{i + 1}' + self.add_module( + module_name, + self._make_layer( + block, in_channels * pow(2, i), num_blocks[i], stride=2)) + + def _make_layer(self, block, out_channels, blocks, stride=1): + downsample = None + if stride != 1 or self.in_channels != out_channels * block.expansion: + downsample = ConvModule( + self.in_channels, + out_channels * block.expansion, + kernel_size=1, + stride=stride, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + + units = list() + units.append( + block( + self.in_channels, + out_channels, + stride=stride, + downsample=downsample, + norm_cfg=self.norm_cfg)) + self.in_channels = out_channels * block.expansion + for _ in range(1, blocks): + units.append(block(self.in_channels, out_channels)) + + return nn.Sequential(*units) + + def forward(self, x, skip1, skip2): + out = list() + for i in range(self.num_units): + module_name = f'layer{i + 1}' + module_i = getattr(self, module_name) + x = module_i(x) + if self.has_skip: + x = x + skip1[i] + skip2[i] + out.append(x) + out.reverse() + + return tuple(out) + + +class UpsampleUnit(BaseModule): + """Upsample unit for upsample module. + + Args: + ind (int): Indicates whether to interpolate (>0) and whether to + generate feature map for the next hourglass-like module. + num_units (int): Number of units that form a upsample module. Along + with ind and gen_cross_conv, nm_units is used to decide whether + to generate feature map for the next hourglass-like module. + in_channels (int): Channel number of the skip-in feature maps from + the corresponding downsample unit. + unit_channels (int): Channel number in this unit. Default:256. + gen_skip: (bool): Whether or not to generate skips for the posterior + downsample module. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + out_channels (int): Number of channels of feature output by upsample + module. Must equal to in_channels of downsample module. Default:64 + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + ind, + num_units, + in_channels, + unit_channels=256, + gen_skip=False, + gen_cross_conv=False, + norm_cfg=dict(type='BN'), + out_channels=64, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.num_units = num_units + self.norm_cfg = norm_cfg + self.in_skip = ConvModule( + in_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + self.relu = nn.ReLU(inplace=True) + + self.ind = ind + if self.ind > 0: + self.up_conv = ConvModule( + unit_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + + self.gen_skip = gen_skip + if self.gen_skip: + self.out_skip1 = ConvModule( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + self.out_skip2 = ConvModule( + unit_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + self.gen_cross_conv = gen_cross_conv + if self.ind == num_units - 1 and self.gen_cross_conv: + self.cross_conv = ConvModule( + unit_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + def forward(self, x, up_x): + out = self.in_skip(x) + + if self.ind > 0: + up_x = F.interpolate( + up_x, + size=(x.size(2), x.size(3)), + mode='bilinear', + align_corners=True) + up_x = self.up_conv(up_x) + out = out + up_x + out = self.relu(out) + + skip1 = None + skip2 = None + if self.gen_skip: + skip1 = self.out_skip1(x) + skip2 = self.out_skip2(out) + + cross_conv = None + if self.ind == self.num_units - 1 and self.gen_cross_conv: + cross_conv = self.cross_conv(out) + + return out, skip1, skip2, cross_conv + + +class UpsampleModule(BaseModule): + """Upsample module for MSPN. + + Args: + unit_channels (int): Channel number in the upsample units. + Default:256. + num_units (int): Numbers of upsample units. Default: 4 + gen_skip (bool): Whether to generate skip for posterior downsample + module or not. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + out_channels (int): Number of channels of feature output by upsample + module. Must equal to in_channels of downsample module. Default:64 + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + unit_channels=256, + num_units=4, + gen_skip=False, + gen_cross_conv=False, + norm_cfg=dict(type='BN'), + out_channels=64, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.in_channels = list() + for i in range(num_units): + self.in_channels.append(Bottleneck.expansion * out_channels * + pow(2, i)) + self.in_channels.reverse() + self.num_units = num_units + self.gen_skip = gen_skip + self.gen_cross_conv = gen_cross_conv + self.norm_cfg = norm_cfg + for i in range(num_units): + module_name = f'up{i + 1}' + self.add_module( + module_name, + UpsampleUnit( + i, + self.num_units, + self.in_channels[i], + unit_channels, + self.gen_skip, + self.gen_cross_conv, + norm_cfg=self.norm_cfg, + out_channels=64)) + + def forward(self, x): + out = list() + skip1 = list() + skip2 = list() + cross_conv = None + for i in range(self.num_units): + module_i = getattr(self, f'up{i + 1}') + if i == 0: + outi, skip1_i, skip2_i, _ = module_i(x[i], None) + elif i == self.num_units - 1: + outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1]) + else: + outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1]) + out.append(outi) + skip1.append(skip1_i) + skip2.append(skip2_i) + skip1.reverse() + skip2.reverse() + + return out, skip1, skip2, cross_conv + + +class SingleStageNetwork(BaseModule): + """Single_stage Network. + + Args: + unit_channels (int): Channel number in the upsample units. Default:256. + num_units (int): Numbers of downsample/upsample units. Default: 4 + gen_skip (bool): Whether to generate skip for posterior downsample + module or not. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + has_skip (bool): Have skip connections from prior upsample + module or not. Default:False + num_blocks (list): Number of blocks in each downsample unit. + Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks) + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + in_channels (int): Number of channels of the feature from ResNetTop. + Default: 64. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + has_skip=False, + gen_skip=False, + gen_cross_conv=False, + unit_channels=256, + num_units=4, + num_blocks=[2, 2, 2, 2], + norm_cfg=dict(type='BN'), + in_channels=64, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + num_blocks = cp.deepcopy(num_blocks) + super().__init__(init_cfg=init_cfg) + assert len(num_blocks) == num_units + self.has_skip = has_skip + self.gen_skip = gen_skip + self.gen_cross_conv = gen_cross_conv + self.num_units = num_units + self.unit_channels = unit_channels + self.num_blocks = num_blocks + self.norm_cfg = norm_cfg + + self.downsample = DownsampleModule(Bottleneck, num_blocks, num_units, + has_skip, norm_cfg, in_channels) + self.upsample = UpsampleModule(unit_channels, num_units, gen_skip, + gen_cross_conv, norm_cfg, in_channels) + + def forward(self, x, skip1, skip2): + mid = self.downsample(x, skip1, skip2) + out, skip1, skip2, cross_conv = self.upsample(mid) + + return out, skip1, skip2, cross_conv + + +class ResNetTop(BaseModule): + """ResNet top for MSPN. + + Args: + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + channels (int): Number of channels of the feature output by ResNetTop. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, norm_cfg=dict(type='BN'), channels=64, init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.top = nn.Sequential( + ConvModule( + 3, + channels, + kernel_size=7, + stride=2, + padding=3, + norm_cfg=norm_cfg, + inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1)) + + def forward(self, img): + return self.top(img) + + +@MODELS.register_module() +class MSPN(BaseBackbone): + """MSPN backbone. Paper ref: Li et al. "Rethinking on Multi-Stage Networks + for Human Pose Estimation" (CVPR 2020). + + Args: + unit_channels (int): Number of Channels in an upsample unit. + Default: 256 + num_stages (int): Number of stages in a multi-stage MSPN. Default: 4 + num_units (int): Number of downsample/upsample units in a single-stage + network. Default: 4 + Note: Make sure num_units == len(self.num_blocks) + num_blocks (list): Number of bottlenecks in each + downsample unit. Default: [2, 2, 2, 2] + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + res_top_channels (int): Number of channels of feature from ResNetTop. + Default: 64. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict( + type='Normal', + std=0.01, + layer=['Linear']), + ]`` + + Example: + >>> from mmpose.models import MSPN + >>> import torch + >>> self = MSPN(num_stages=2,num_units=2,num_blocks=[2,2]) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 511, 511) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... for feature in level_output: + ... print(tuple(feature.shape)) + ... + (1, 256, 64, 64) + (1, 256, 128, 128) + (1, 256, 64, 64) + (1, 256, 128, 128) + """ + + def __init__(self, + unit_channels=256, + num_stages=4, + num_units=4, + num_blocks=[2, 2, 2, 2], + norm_cfg=dict(type='BN'), + res_top_channels=64, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict(type='Normal', std=0.01, layer=['Linear']), + ]): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + num_blocks = cp.deepcopy(num_blocks) + super().__init__(init_cfg=init_cfg) + self.unit_channels = unit_channels + self.num_stages = num_stages + self.num_units = num_units + self.num_blocks = num_blocks + self.norm_cfg = norm_cfg + + assert self.num_stages > 0 + assert self.num_units > 1 + assert self.num_units == len(self.num_blocks) + self.top = ResNetTop(norm_cfg=norm_cfg) + self.multi_stage_mspn = nn.ModuleList([]) + for i in range(self.num_stages): + if i == 0: + has_skip = False + else: + has_skip = True + if i != self.num_stages - 1: + gen_skip = True + gen_cross_conv = True + else: + gen_skip = False + gen_cross_conv = False + self.multi_stage_mspn.append( + SingleStageNetwork(has_skip, gen_skip, gen_cross_conv, + unit_channels, num_units, num_blocks, + norm_cfg, res_top_channels)) + + def forward(self, x): + """Model forward function.""" + out_feats = [] + skip1 = None + skip2 = None + x = self.top(x) + for i in range(self.num_stages): + out, skip1, skip2, x = self.multi_stage_mspn[i](x, skip1, skip2) + out_feats.append(out) + + return out_feats + + def init_weights(self): + """Initialize model weights.""" + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + logger = get_root_logger() + state_dict_tmp = get_state_dict(self.init_cfg['checkpoint']) + state_dict = OrderedDict() + state_dict['top'] = OrderedDict() + state_dict['bottlenecks'] = OrderedDict() + for k, v in state_dict_tmp.items(): + if k.startswith('layer'): + if 'downsample.0' in k: + state_dict['bottlenecks'][k.replace( + 'downsample.0', 'downsample.conv')] = v + elif 'downsample.1' in k: + state_dict['bottlenecks'][k.replace( + 'downsample.1', 'downsample.bn')] = v + else: + state_dict['bottlenecks'][k] = v + elif k.startswith('conv1'): + state_dict['top'][k.replace('conv1', 'top.0.conv')] = v + elif k.startswith('bn1'): + state_dict['top'][k.replace('bn1', 'top.0.bn')] = v + + load_state_dict( + self.top, state_dict['top'], strict=False, logger=logger) + for i in range(self.num_stages): + load_state_dict( + self.multi_stage_mspn[i].downsample, + state_dict['bottlenecks'], + strict=False, + logger=logger) + else: + super(MSPN, self).init_weights() diff --git a/mmpose/models/backbones/pvt.py b/mmpose/models/backbones/pvt.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2b6495482b4feadd86f51fa11b64ee10878fef --- /dev/null +++ b/mmpose/models/backbones/pvt.py @@ -0,0 +1,569 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import MultiheadAttention +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.model.weight_init import trunc_normal_ +from mmengine.runner import load_state_dict +from mmengine.utils import to_2tuple + +from mmpose.registry import MODELS +from ...utils import get_root_logger +from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw, pvt_convert +from .utils import get_state_dict + + +class MixFFN(BaseModule): + """An implementation of MixFFN of PVT. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Depth-wise Conv to encode positional information. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. + feedforward_channels (int): The hidden dimension of FFNs. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='GELU'). + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + Default: None. + use_conv (bool): If True, add 3x3 DWConv between two Linear layers. + Defaults: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + dropout_layer=None, + use_conv=False, + init_cfg=None): + super(MixFFN, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + activate = build_activation_layer(act_cfg) + + in_channels = embed_dims + fc1 = Conv2d( + in_channels=in_channels, + out_channels=feedforward_channels, + kernel_size=1, + stride=1, + bias=True) + if use_conv: + # 3x3 depth wise conv to provide positional encode information + dw_conv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=(3 - 1) // 2, + bias=True, + groups=feedforward_channels) + fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True) + drop = nn.Dropout(ffn_drop) + layers = [fc1, activate, drop, fc2, drop] + if use_conv: + layers.insert(1, dw_conv) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + + def forward(self, x, hw_shape, identity=None): + out = nlc_to_nchw(x, hw_shape) + out = self.layers(out) + out = nchw_to_nlc(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class SpatialReductionAttention(MultiheadAttention): + """An implementation of Spatial Reduction Attention of PVT. + + This module is modified from MultiheadAttention which is a module from + mmcv.cnn.bricks.transformer. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default: True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Spatial Reduction + Attention of PVT. Default: 1. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + batch_first=True, + qkv_bias=True, + norm_cfg=dict(type='LN'), + sr_ratio=1, + init_cfg=None): + super().__init__( + embed_dims, + num_heads, + attn_drop, + proj_drop, + batch_first=batch_first, + dropout_layer=dropout_layer, + bias=qkv_bias, + init_cfg=init_cfg) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa + from mmpose import digit_version, mmcv_version + if mmcv_version < digit_version('1.3.17'): + warnings.warn('The legacy version of forward function in' + 'SpatialReductionAttention is deprecated in' + 'mmcv>=1.3.17 and will no longer support in the' + 'future. Please upgrade your mmcv.') + self.forward = self.legacy_forward + + def forward(self, x, hw_shape, identity=None): + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + x_q = x_q.transpose(0, 1) + x_kv = x_kv.transpose(0, 1) + + out = self.attn(query=x_q, key=x_kv, value=x_kv)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + def legacy_forward(self, x, hw_shape, identity=None): + """multi head attention forward in mmcv version < 1.3.17.""" + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + out = self.attn(query=x_q, key=x_kv, value=x_kv)[0] + + return identity + self.dropout_layer(self.proj_drop(out)) + + +class PVTEncoderLayer(BaseModule): + """Implements one encoder layer in PVT. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): stochastic depth rate. Default: 0.0. + qkv_bias (bool): enable bias for qkv if True. + Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Spatial Reduction + Attention of PVT. Default: 1. + use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1, + use_conv_ffn=False, + init_cfg=None): + super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg) + + # The ret[0] of build_norm_layer is norm name. + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.attn = SpatialReductionAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + # The ret[0] of build_norm_layer is norm name. + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.ffn = MixFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + use_conv=use_conv_ffn, + act_cfg=act_cfg) + + def forward(self, x, hw_shape): + x = self.attn(self.norm1(x), hw_shape, identity=x) + x = self.ffn(self.norm2(x), hw_shape, identity=x) + + return x + + +class AbsolutePositionEmbedding(BaseModule): + """An implementation of the absolute position embedding in PVT. + + Args: + pos_shape (int): The shape of the absolute position embedding. + pos_dim (int): The dimension of the absolute position embedding. + drop_rate (float): Probability of an element to be zeroed. + Default: 0.0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(pos_shape, int): + pos_shape = to_2tuple(pos_shape) + elif isinstance(pos_shape, tuple): + if len(pos_shape) == 1: + pos_shape = to_2tuple(pos_shape[0]) + assert len(pos_shape) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pos_shape)}' + self.pos_shape = pos_shape + self.pos_dim = pos_dim + + self.pos_embed = nn.Parameter( + torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim)) + self.drop = nn.Dropout(p=drop_rate) + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + + def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'): + """Resize pos_embed weights. + + Resize pos_embed using bilinear interpolate method. + + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shape (tuple): Tuple for (downsampled input image height, + downsampled input image width). + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'bilinear'``. + + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C]. + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = self.pos_shape + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous() + pos_embed_weight = F.interpolate( + pos_embed_weight, size=input_shape, mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, + 2).transpose(1, 2).contiguous() + pos_embed = pos_embed_weight + + return pos_embed + + def forward(self, x, hw_shape, mode='bilinear'): + pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode) + return self.drop(x + pos_embed) + + +@MODELS.register_module() +class PyramidVisionTransformer(BaseModule): + """Pyramid Vision Transformer (PVT) + + Implementation of `Pyramid Vision Transformer: A Versatile Backbone for + Dense Prediction without Convolutions + `_. + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 64. + num_stags (int): The num of stages. Default: 4. + num_layers (Sequence[int]): The layer number of each transformer encode + layer. Default: [3, 4, 6, 3]. + num_heads (Sequence[int]): The attention heads of each transformer + encode layer. Default: [1, 2, 5, 8]. + patch_sizes (Sequence[int]): The patch_size of each patch embedding. + Default: [4, 2, 2, 2]. + strides (Sequence[int]): The stride of each patch embedding. + Default: [4, 2, 2, 2]. + paddings (Sequence[int]): The padding of each patch embedding. + Default: [0, 0, 0, 0]. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer encode layer. Default: [8, 4, 2, 1]. + out_indices (Sequence[int] | int): Output from which stages. + Default: (0, 1, 2, 3). + mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the + embedding dim of each transformer encode layer. + Default: [8, 8, 4, 4]. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: True. + use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN. + Default: False. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + pretrained (str, optional): model pretrained path. Default: None. + convert_weights (bool): The flag indicates whether the + pre-trained model is from the original repo. We may need + to convert some keys to make it compatible. + Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='TruncNormal', std=.02, layer=['Linear']), + dict(type='Constant', val=1, layer=['LayerNorm']), + dict(type='Normal', std=0.01, layer=['Conv2d']) + ]`` + """ + + def __init__(self, + pretrain_img_size=224, + in_channels=3, + embed_dims=64, + num_stages=4, + num_layers=[3, 4, 6, 3], + num_heads=[1, 2, 5, 8], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + paddings=[0, 0, 0, 0], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + use_abs_pos_embed=True, + norm_after_stage=False, + use_conv_ffn=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN', eps=1e-6), + convert_weights=True, + init_cfg=[ + dict(type='TruncNormal', std=.02, layer=['Linear']), + dict(type='Constant', val=1, layer=['LayerNorm']), + dict(type='Kaiming', layer=['Conv2d']) + ]): + super().__init__(init_cfg=init_cfg) + + self.convert_weights = convert_weights + if isinstance(pretrain_img_size, int): + pretrain_img_size = to_2tuple(pretrain_img_size) + elif isinstance(pretrain_img_size, tuple): + if len(pretrain_img_size) == 1: + pretrain_img_size = to_2tuple(pretrain_img_size[0]) + assert len(pretrain_img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pretrain_img_size)}' + + self.embed_dims = embed_dims + + self.num_stages = num_stages + self.num_layers = num_layers + self.num_heads = num_heads + self.patch_sizes = patch_sizes + self.strides = strides + self.sr_ratios = sr_ratios + assert num_stages == len(num_layers) == len(num_heads) \ + == len(patch_sizes) == len(strides) == len(sr_ratios) + + self.out_indices = out_indices + assert max(out_indices) < self.num_stages + + # transformer encoder + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + ] # stochastic num_layer decay rule + + cur = 0 + self.layers = ModuleList() + for i, num_layer in enumerate(num_layers): + embed_dims_i = embed_dims * num_heads[i] + patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims_i, + kernel_size=patch_sizes[i], + stride=strides[i], + padding=paddings[i], + bias=True, + norm_cfg=norm_cfg) + + layers = ModuleList() + if use_abs_pos_embed: + pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1]) + pos_embed = AbsolutePositionEmbedding( + pos_shape=pos_shape, + pos_dim=embed_dims_i, + drop_rate=drop_rate) + layers.append(pos_embed) + layers.extend([ + PVTEncoderLayer( + embed_dims=embed_dims_i, + num_heads=num_heads[i], + feedforward_channels=mlp_ratios[i] * embed_dims_i, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[cur + idx], + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + sr_ratio=sr_ratios[i], + use_conv_ffn=use_conv_ffn) for idx in range(num_layer) + ]) + in_channels = embed_dims_i + # The ret[0] of build_norm_layer is norm name. + if norm_after_stage: + norm = build_norm_layer(norm_cfg, embed_dims_i)[1] + else: + norm = nn.Identity() + self.layers.append(ModuleList([patch_embed, layers, norm])) + cur += num_layer + + def init_weights(self): + """Initialize the weights in backbone.""" + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + logger = get_root_logger() + state_dict = get_state_dict( + self.init_cfg['checkpoint'], map_location='cpu') + logger.warn(f'Load pre-trained model for ' + f'{self.__class__.__name__} from original repo') + + if self.convert_weights: + # Because pvt backbones are not supported by mmcls, + # so we need to convert pre-trained weights to match this + # implementation. + state_dict = pvt_convert(state_dict) + load_state_dict(self, state_dict, strict=False, logger=logger) + + else: + super(PyramidVisionTransformer, self).init_weights() + + def forward(self, x): + outs = [] + + for i, layer in enumerate(self.layers): + x, hw_shape = layer[0](x) + + for block in layer[1]: + x = block(x, hw_shape) + x = layer[2](x) + x = nlc_to_nchw(x, hw_shape) + if i in self.out_indices: + outs.append(x) + + return outs + + +@MODELS.register_module() +class PyramidVisionTransformerV2(PyramidVisionTransformer): + """Implementation of `PVTv2: Improved Baselines with Pyramid Vision + Transformer `_.""" + + def __init__(self, **kwargs): + super(PyramidVisionTransformerV2, self).__init__( + patch_sizes=[7, 3, 3, 3], + paddings=[3, 1, 1, 1], + use_abs_pos_embed=False, + norm_after_stage=True, + use_conv_ffn=True, + **kwargs) diff --git a/mmpose/models/backbones/regnet.py b/mmpose/models/backbones/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..120523e658ecb2b3134eba45508ac47457a87f1d --- /dev/null +++ b/mmpose/models/backbones/regnet.py @@ -0,0 +1,331 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import numpy as np +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpose.registry import MODELS +from .resnet import ResNet +from .resnext import Bottleneck + + +@MODELS.register_module() +class RegNet(ResNet): + """RegNet backbone. + + More details can be found in `paper `__ . + + Args: + arch (dict): The parameter of RegNets. + - w0 (int): initial width + - wa (float): slope of width + - wm (float): quantization parameter to quantize the width + - depth (int): depth of the backbone + - group_w (int): width of group + - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck. + strides (Sequence[int]): Strides of the first block of each stage. + base_channels (int): Base channels after stem layer. + in_channels (int): Number of input image channels. Default: 3. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: "pytorch". + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Default: -1. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import RegNet + >>> import torch + >>> self = RegNet( + arch=dict( + w0=88, + wa=26.31, + wm=2.25, + group_w=48, + depth=25, + bot_mul=1.0), + out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 96, 8, 8) + (1, 192, 4, 4) + (1, 432, 2, 2) + (1, 1008, 1, 1) + """ + arch_settings = { + 'regnetx_400mf': + dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0), + 'regnetx_800mf': + dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0), + 'regnetx_1.6gf': + dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0), + 'regnetx_3.2gf': + dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0), + 'regnetx_4.0gf': + dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0), + 'regnetx_6.4gf': + dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0), + 'regnetx_8.0gf': + dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0), + 'regnetx_12gf': + dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0), + } + + def __init__(self, + arch, + in_channels=3, + stem_channels=32, + base_channels=32, + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super(ResNet, self).__init__(init_cfg=init_cfg) + + # Generate RegNet parameters first + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the' \ + ' arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + widths, num_stages = self.generate_regnet( + arch['w0'], + arch['wa'], + arch['wm'], + arch['depth'], + ) + # Convert to per stage format + stage_widths, stage_blocks = self.get_stages_from_blocks(widths) + # Generate group widths and bot muls + group_widths = [arch['group_w'] for _ in range(num_stages)] + self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)] + # Adjust the compatibility of stage_widths and group_widths + stage_widths, group_widths = self.adjust_width_group( + stage_widths, self.bottleneck_ratio, group_widths) + + # Group params by stage + self.stage_widths = stage_widths + self.group_widths = group_widths + self.depth = sum(stage_blocks) + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert 1 <= num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + if self.deep_stem: + raise NotImplementedError( + 'deep_stem has not been implemented for RegNet') + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.stage_blocks = stage_blocks[:num_stages] + + self._make_stem_layer(in_channels, stem_channels) + + _in_channels = stem_channels + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = self.strides[i] + dilation = self.dilations[i] + group_width = self.group_widths[i] + width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i])) + stage_groups = width // group_width + + res_layer = self.make_res_layer( + block=Bottleneck, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=self.stage_widths[i], + expansion=1, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + base_channels=self.stage_widths[i], + groups=stage_groups, + width_per_group=group_width) + _in_channels = self.stage_widths[i] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = stage_widths[-1] + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + @staticmethod + def generate_regnet(initial_width, + width_slope, + width_parameter, + depth, + divisor=8): + """Generates per block width from RegNet parameters. + + Args: + initial_width ([int]): Initial width of the backbone + width_slope ([float]): Slope of the quantized linear function + width_parameter ([int]): Parameter used to quantize the width. + depth ([int]): Depth of the backbone. + divisor (int, optional): The divisor of channels. Defaults to 8. + + Returns: + list, int: return a list of widths of each stage and the number of + stages + """ + assert width_slope >= 0 + assert initial_width > 0 + assert width_parameter > 1 + assert initial_width % divisor == 0 + widths_cont = np.arange(depth) * width_slope + initial_width + ks = np.round( + np.log(widths_cont / initial_width) / np.log(width_parameter)) + widths = initial_width * np.power(width_parameter, ks) + widths = np.round(np.divide(widths, divisor)) * divisor + num_stages = len(np.unique(widths)) + widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist() + return widths, num_stages + + @staticmethod + def quantize_float(number, divisor): + """Converts a float to closest non-zero int divisible by divior. + + Args: + number (int): Original number to be quantized. + divisor (int): Divisor used to quantize the number. + + Returns: + int: quantized number that is divisible by devisor. + """ + return int(round(number / divisor) * divisor) + + def adjust_width_group(self, widths, bottleneck_ratio, groups): + """Adjusts the compatibility of widths and groups. + + Args: + widths (list[int]): Width of each stage. + bottleneck_ratio (float): Bottleneck ratio. + groups (int): number of groups in each stage + + Returns: + tuple(list): The adjusted widths and groups of each stage. + """ + bottleneck_width = [ + int(w * b) for w, b in zip(widths, bottleneck_ratio) + ] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)] + bottleneck_width = [ + self.quantize_float(w_bot, g) + for w_bot, g in zip(bottleneck_width, groups) + ] + widths = [ + int(w_bot / b) + for w_bot, b in zip(bottleneck_width, bottleneck_ratio) + ] + return widths, groups + + def get_stages_from_blocks(self, widths): + """Gets widths/stage_blocks of network at each stage. + + Args: + widths (list[int]): Width in each stage. + + Returns: + tuple(list): width and depth of each stage + """ + width_diff = [ + width != width_prev + for width, width_prev in zip(widths + [0], [0] + widths) + ] + stage_widths = [ + width for width, diff in zip(widths, width_diff[:-1]) if diff + ] + stage_blocks = np.diff([ + depth for depth, diff in zip(range(len(width_diff)), width_diff) + if diff + ]).tolist() + return stage_widths, stage_blocks + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpose/models/backbones/resnest.py b/mmpose/models/backbones/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..b5eea8ad7e50c2ab997e2df17316943fcaf3a5fe --- /dev/null +++ b/mmpose/models/backbones/resnest.py @@ -0,0 +1,353 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(BaseModule): + """Split-Attention Conv2d. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + return getattr(self, self.norm0_name) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + groups=1, + width_per_group=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = SplitAttentionConv2d( + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + Please refer to the `paper `__ + for details. + + Args: + depth (int): Network depth, from {50, 101, 152, 200}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)), + 269: (Bottleneck, (3, 30, 48, 8)) + } + + def __init__(self, + depth, + groups=1, + width_per_group=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.width_per_group = width_per_group + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super().__init__(depth=depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/mmpose/models/backbones/resnet.py b/mmpose/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a04853f60d179ee2450ca199b0a8c28ae893941f --- /dev/null +++ b/mmpose/models/backbones/resnet.py @@ -0,0 +1,715 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, constant_init +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +class BasicBlock(BaseModule): + """BasicBlock for ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the output channels of conv1. This is a + reserved argument in BasicBlock and should always be 1. Default: 1. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None. + style (str): `pytorch` or `caffe`. It is unused and reserved for + unified API with Bottleneck. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + expansion=1, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert self.expansion == 1 + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, out_channels, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + 3, + padding=1, + bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the input/output channels of conv2. Default: 4. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None. + style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the + stride-two layer is the 3x3 conv layer, otherwise the stride-two + layer is the first 1x1 conv layer. Default: "pytorch". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + expansion=4, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + assert style in ['pytorch', 'caffe'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, out_channels, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: the normalization layer named "norm3" """ + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def get_expansion(block, expansion=None): + """Get the expansion of a residual block. + + The block expansion will be obtained by the following order: + + 1. If ``expansion`` is given, just return it. + 2. If ``block`` has the attribute ``expansion``, then return + ``block.expansion``. + 3. Return the default value according the the block type: + 1 for ``BasicBlock`` and 4 for ``Bottleneck``. + + Args: + block (class): The block class. + expansion (int | None): The given expansion ratio. + + Returns: + int: The expansion of the block. + """ + if isinstance(expansion, int): + assert expansion > 0 + elif expansion is None: + if hasattr(block, 'expansion'): + expansion = block.expansion + elif issubclass(block, BasicBlock): + expansion = 1 + elif issubclass(block, Bottleneck): + expansion = 4 + else: + raise TypeError(f'expansion is not specified for {block.__name__}') + else: + raise TypeError('expansion must be an integer or None') + + return expansion + + +class ResLayer(nn.Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): Residual block used to build ResLayer. + num_blocks (int): Number of blocks. + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int, optional): The expansion for BasicBlock/Bottleneck. + If not specified, it will firstly be obtained via + ``block.expansion``. If the block has no attribute "expansion", + the following default values will be used: 1 for BasicBlock and + 4 for Bottleneck. Default: None. + stride (int): stride of the first block. Default: 1. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. Default: True + """ + + def __init__(self, + block, + num_blocks, + in_channels, + out_channels, + expansion=None, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + downsample_first=True, + **kwargs): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + self.block = block + self.expansion = get_expansion(block, expansion) + + downsample = None + if stride != 1 or in_channels != out_channels: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if downsample_first: + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + in_channels = out_channels + for _ in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + else: # downsample_first=False is for HourglassModule + for i in range(0, num_blocks - 1): + layers.append( + block( + in_channels=in_channels, + out_channels=in_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + + super().__init__(*layers) + + +@MODELS.register_module() +class ResNet(BaseBackbone): + """ResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + base_channels (int): Middle channels of the first stage. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import ResNet + >>> import torch + >>> self = ResNet(depth=18, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + expansion=None, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super(ResNet, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert 1 <= num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.expansion = get_expansion(self.block, expansion) + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + _in_channels = stem_channels + _out_channels = base_channels * self.expansion + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + res_layer = self.make_res_layer( + block=self.block, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=_out_channels, + expansion=self.expansion, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + _in_channels = _out_channels + _out_channels *= 2 + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = res_layer[-1].out_channels + + def make_res_layer(self, **kwargs): + """Make a ResLayer.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + """Make stem layer.""" + if self.deep_stem: + self.stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone.""" + super(ResNet, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress zero_init_residual if use pretrained model. + return + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() +class ResNetV1d(ResNet): + r"""ResNetV1d variant described in `Bag of Tricks + `__. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + """ + + def __init__(self, **kwargs): + super().__init__(deep_stem=True, avg_down=True, **kwargs) diff --git a/mmpose/models/backbones/resnext.py b/mmpose/models/backbones/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..241f83a11449d3e816d4dbb16bd5715cf9ba6e3f --- /dev/null +++ b/mmpose/models/backbones/resnext.py @@ -0,0 +1,171 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpose.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import ResNeXt + >>> import torch + >>> self = ResNeXt(depth=50, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super().__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpose/models/backbones/rsn.py b/mmpose/models/backbones/rsn.py new file mode 100644 index 0000000000000000000000000000000000000000..8267d23d952f9639dff524cfea8e8d111ce19584 --- /dev/null +++ b/mmpose/models/backbones/rsn.py @@ -0,0 +1,640 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy as cp + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, MaxPool2d +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +class RSB(BaseModule): + """Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate + Local Representations for Multi-Person Pose Estimation" (ECCV 2020). + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + num_steps (int): Numbers of steps in RSB + stride (int): stride of the block. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + expand_times (int): Times by which the in_channels are expanded. + Default:26. + res_top_channels (int): Number of channels of feature output by + ResNet_top. Default:64. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + expansion = 1 + + def __init__(self, + in_channels, + out_channels, + num_steps=4, + stride=1, + downsample=None, + with_cp=False, + norm_cfg=dict(type='BN'), + expand_times=26, + res_top_channels=64, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + assert num_steps > 1 + self.in_channels = in_channels + self.branch_channels = self.in_channels * expand_times + self.branch_channels //= res_top_channels + self.out_channels = out_channels + self.stride = stride + self.downsample = downsample + self.with_cp = with_cp + self.norm_cfg = norm_cfg + self.num_steps = num_steps + self.conv_bn_relu1 = ConvModule( + self.in_channels, + self.num_steps * self.branch_channels, + kernel_size=1, + stride=self.stride, + padding=0, + norm_cfg=self.norm_cfg, + inplace=False) + for i in range(self.num_steps): + for j in range(i + 1): + module_name = f'conv_bn_relu2_{i + 1}_{j + 1}' + self.add_module( + module_name, + ConvModule( + self.branch_channels, + self.branch_channels, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg, + inplace=False)) + self.conv_bn3 = ConvModule( + self.num_steps * self.branch_channels, + self.out_channels * self.expansion, + kernel_size=1, + stride=1, + padding=0, + act_cfg=None, + norm_cfg=self.norm_cfg, + inplace=False) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + """Forward function.""" + + identity = x + x = self.conv_bn_relu1(x) + spx = torch.split(x, self.branch_channels, 1) + outputs = list() + outs = list() + for i in range(self.num_steps): + outputs_i = list() + outputs.append(outputs_i) + for j in range(i + 1): + if j == 0: + inputs = spx[i] + else: + inputs = outputs[i][j - 1] + if i > j: + inputs = inputs + outputs[i - 1][j] + module_name = f'conv_bn_relu2_{i + 1}_{j + 1}' + module_i_j = getattr(self, module_name) + outputs[i].append(module_i_j(inputs)) + + outs.append(outputs[i][i]) + out = torch.cat(tuple(outs), 1) + out = self.conv_bn3(out) + + if self.downsample is not None: + identity = self.downsample(identity) + out = out + identity + + out = self.relu(out) + + return out + + +class Downsample_module(BaseModule): + """Downsample module for RSN. + + Args: + block (nn.Module): Downsample block. + num_blocks (list): Number of blocks in each downsample unit. + num_units (int): Numbers of downsample units. Default: 4 + has_skip (bool): Have skip connections from prior upsample + module or not. Default:False + num_steps (int): Number of steps in a block. Default:4 + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + in_channels (int): Number of channels of the input feature to + downsample module. Default: 64 + expand_times (int): Times by which the in_channels are expanded. + Default:26. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + block, + num_blocks, + num_steps=4, + num_units=4, + has_skip=False, + norm_cfg=dict(type='BN'), + in_channels=64, + expand_times=26, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.has_skip = has_skip + self.in_channels = in_channels + assert len(num_blocks) == num_units + self.num_blocks = num_blocks + self.num_units = num_units + self.num_steps = num_steps + self.norm_cfg = norm_cfg + self.layer1 = self._make_layer( + block, + in_channels, + num_blocks[0], + expand_times=expand_times, + res_top_channels=in_channels) + for i in range(1, num_units): + module_name = f'layer{i + 1}' + self.add_module( + module_name, + self._make_layer( + block, + in_channels * pow(2, i), + num_blocks[i], + stride=2, + expand_times=expand_times, + res_top_channels=in_channels)) + + def _make_layer(self, + block, + out_channels, + blocks, + stride=1, + expand_times=26, + res_top_channels=64): + downsample = None + if stride != 1 or self.in_channels != out_channels * block.expansion: + downsample = ConvModule( + self.in_channels, + out_channels * block.expansion, + kernel_size=1, + stride=stride, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + + units = list() + units.append( + block( + self.in_channels, + out_channels, + num_steps=self.num_steps, + stride=stride, + downsample=downsample, + norm_cfg=self.norm_cfg, + expand_times=expand_times, + res_top_channels=res_top_channels)) + self.in_channels = out_channels * block.expansion + for _ in range(1, blocks): + units.append( + block( + self.in_channels, + out_channels, + num_steps=self.num_steps, + expand_times=expand_times, + res_top_channels=res_top_channels)) + + return nn.Sequential(*units) + + def forward(self, x, skip1, skip2): + out = list() + for i in range(self.num_units): + module_name = f'layer{i + 1}' + module_i = getattr(self, module_name) + x = module_i(x) + if self.has_skip: + x = x + skip1[i] + skip2[i] + out.append(x) + out.reverse() + + return tuple(out) + + +class Upsample_unit(BaseModule): + """Upsample unit for upsample module. + + Args: + ind (int): Indicates whether to interpolate (>0) and whether to + generate feature map for the next hourglass-like module. + num_units (int): Number of units that form a upsample module. Along + with ind and gen_cross_conv, nm_units is used to decide whether + to generate feature map for the next hourglass-like module. + in_channels (int): Channel number of the skip-in feature maps from + the corresponding downsample unit. + unit_channels (int): Channel number in this unit. Default:256. + gen_skip: (bool): Whether or not to generate skips for the posterior + downsample module. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + out_channels (in): Number of channels of feature output by upsample + module. Must equal to in_channels of downsample module. Default:64 + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + ind, + num_units, + in_channels, + unit_channels=256, + gen_skip=False, + gen_cross_conv=False, + norm_cfg=dict(type='BN'), + out_channels=64, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.num_units = num_units + self.norm_cfg = norm_cfg + self.in_skip = ConvModule( + in_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + self.relu = nn.ReLU(inplace=True) + + self.ind = ind + if self.ind > 0: + self.up_conv = ConvModule( + unit_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + act_cfg=None, + inplace=True) + + self.gen_skip = gen_skip + if self.gen_skip: + self.out_skip1 = ConvModule( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + self.out_skip2 = ConvModule( + unit_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + self.gen_cross_conv = gen_cross_conv + if self.ind == num_units - 1 and self.gen_cross_conv: + self.cross_conv = ConvModule( + unit_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=self.norm_cfg, + inplace=True) + + def forward(self, x, up_x): + out = self.in_skip(x) + + if self.ind > 0: + up_x = F.interpolate( + up_x, + size=(x.size(2), x.size(3)), + mode='bilinear', + align_corners=True) + up_x = self.up_conv(up_x) + out = out + up_x + out = self.relu(out) + + skip1 = None + skip2 = None + if self.gen_skip: + skip1 = self.out_skip1(x) + skip2 = self.out_skip2(out) + + cross_conv = None + if self.ind == self.num_units - 1 and self.gen_cross_conv: + cross_conv = self.cross_conv(out) + + return out, skip1, skip2, cross_conv + + +class Upsample_module(BaseModule): + """Upsample module for RSN. + + Args: + unit_channels (int): Channel number in the upsample units. + Default:256. + num_units (int): Numbers of upsample units. Default: 4 + gen_skip (bool): Whether to generate skip for posterior downsample + module or not. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + out_channels (int): Number of channels of feature output by upsample + module. Must equal to in_channels of downsample module. Default:64 + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + unit_channels=256, + num_units=4, + gen_skip=False, + gen_cross_conv=False, + norm_cfg=dict(type='BN'), + out_channels=64, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.in_channels = list() + for i in range(num_units): + self.in_channels.append(RSB.expansion * out_channels * pow(2, i)) + self.in_channels.reverse() + self.num_units = num_units + self.gen_skip = gen_skip + self.gen_cross_conv = gen_cross_conv + self.norm_cfg = norm_cfg + for i in range(num_units): + module_name = f'up{i + 1}' + self.add_module( + module_name, + Upsample_unit( + i, + self.num_units, + self.in_channels[i], + unit_channels, + self.gen_skip, + self.gen_cross_conv, + norm_cfg=self.norm_cfg, + out_channels=64)) + + def forward(self, x): + out = list() + skip1 = list() + skip2 = list() + cross_conv = None + for i in range(self.num_units): + module_i = getattr(self, f'up{i + 1}') + if i == 0: + outi, skip1_i, skip2_i, _ = module_i(x[i], None) + elif i == self.num_units - 1: + outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1]) + else: + outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1]) + out.append(outi) + skip1.append(skip1_i) + skip2.append(skip2_i) + skip1.reverse() + skip2.reverse() + + return out, skip1, skip2, cross_conv + + +class Single_stage_RSN(BaseModule): + """Single_stage Residual Steps Network. + + Args: + unit_channels (int): Channel number in the upsample units. Default:256. + num_units (int): Numbers of downsample/upsample units. Default: 4 + gen_skip (bool): Whether to generate skip for posterior downsample + module or not. Default:False + gen_cross_conv (bool): Whether to generate feature map for the next + hourglass-like module. Default:False + has_skip (bool): Have skip connections from prior upsample + module or not. Default:False + num_steps (int): Number of steps in RSB. Default: 4 + num_blocks (list): Number of blocks in each downsample unit. + Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks) + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + in_channels (int): Number of channels of the feature from ResNet_Top. + Default: 64. + expand_times (int): Times by which the in_channels are expanded in RSB. + Default:26. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + has_skip=False, + gen_skip=False, + gen_cross_conv=False, + unit_channels=256, + num_units=4, + num_steps=4, + num_blocks=[2, 2, 2, 2], + norm_cfg=dict(type='BN'), + in_channels=64, + expand_times=26, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + num_blocks = cp.deepcopy(num_blocks) + super().__init__(init_cfg=init_cfg) + assert len(num_blocks) == num_units + self.has_skip = has_skip + self.gen_skip = gen_skip + self.gen_cross_conv = gen_cross_conv + self.num_units = num_units + self.num_steps = num_steps + self.unit_channels = unit_channels + self.num_blocks = num_blocks + self.norm_cfg = norm_cfg + + self.downsample = Downsample_module(RSB, num_blocks, num_steps, + num_units, has_skip, norm_cfg, + in_channels, expand_times) + self.upsample = Upsample_module(unit_channels, num_units, gen_skip, + gen_cross_conv, norm_cfg, in_channels) + + def forward(self, x, skip1, skip2): + mid = self.downsample(x, skip1, skip2) + out, skip1, skip2, cross_conv = self.upsample(mid) + + return out, skip1, skip2, cross_conv + + +class ResNet_top(BaseModule): + """ResNet top for RSN. + + Args: + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + channels (int): Number of channels of the feature output by ResNet_top. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, norm_cfg=dict(type='BN'), channels=64, init_cfg=None): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.top = nn.Sequential( + ConvModule( + 3, + channels, + kernel_size=7, + stride=2, + padding=3, + norm_cfg=norm_cfg, + inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1)) + + def forward(self, img): + return self.top(img) + + +@MODELS.register_module() +class RSN(BaseBackbone): + """Residual Steps Network backbone. Paper ref: Cai et al. "Learning + Delicate Local Representations for Multi-Person Pose Estimation" (ECCV + 2020). + + Args: + unit_channels (int): Number of Channels in an upsample unit. + Default: 256 + num_stages (int): Number of stages in a multi-stage RSN. Default: 4 + num_units (int): NUmber of downsample/upsample units in a single-stage + RSN. Default: 4 Note: Make sure num_units == len(self.num_blocks) + num_blocks (list): Number of RSBs (Residual Steps Block) in each + downsample unit. Default: [2, 2, 2, 2] + num_steps (int): Number of steps in a RSB. Default:4 + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + res_top_channels (int): Number of channels of feature from ResNet_top. + Default: 64. + expand_times (int): Times by which the in_channels are expanded in RSB. + Default:26. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict( + type='Normal', + std=0.01, + layer=['Linear']), + ]`` + Example: + >>> from mmpose.models import RSN + >>> import torch + >>> self = RSN(num_stages=2,num_units=2,num_blocks=[2,2]) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 511, 511) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... for feature in level_output: + ... print(tuple(feature.shape)) + ... + (1, 256, 64, 64) + (1, 256, 128, 128) + (1, 256, 64, 64) + (1, 256, 128, 128) + """ + + def __init__(self, + unit_channels=256, + num_stages=4, + num_units=4, + num_blocks=[2, 2, 2, 2], + num_steps=4, + norm_cfg=dict(type='BN'), + res_top_channels=64, + expand_times=26, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict(type='Normal', std=0.01, layer=['Linear']), + ]): + # Protect mutable default arguments + norm_cfg = cp.deepcopy(norm_cfg) + num_blocks = cp.deepcopy(num_blocks) + super().__init__(init_cfg=init_cfg) + self.unit_channels = unit_channels + self.num_stages = num_stages + self.num_units = num_units + self.num_blocks = num_blocks + self.num_steps = num_steps + self.norm_cfg = norm_cfg + + assert self.num_stages > 0 + assert self.num_steps > 1 + assert self.num_units > 1 + assert self.num_units == len(self.num_blocks) + self.top = ResNet_top(norm_cfg=norm_cfg) + self.multi_stage_rsn = nn.ModuleList([]) + for i in range(self.num_stages): + if i == 0: + has_skip = False + else: + has_skip = True + if i != self.num_stages - 1: + gen_skip = True + gen_cross_conv = True + else: + gen_skip = False + gen_cross_conv = False + self.multi_stage_rsn.append( + Single_stage_RSN(has_skip, gen_skip, gen_cross_conv, + unit_channels, num_units, num_steps, + num_blocks, norm_cfg, res_top_channels, + expand_times)) + + def forward(self, x): + """Model forward function.""" + out_feats = [] + skip1 = None + skip2 = None + x = self.top(x) + for i in range(self.num_stages): + out, skip1, skip2, x = self.multi_stage_rsn[i](x, skip1, skip2) + out_feats.append(out) + + return out_feats diff --git a/mmpose/models/backbones/scnet.py b/mmpose/models/backbones/scnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5c802d256e711aa70c955ac5bb91d2f7ff724604 --- /dev/null +++ b/mmpose/models/backbones/scnet.py @@ -0,0 +1,252 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from .resnet import Bottleneck, ResNet + + +class SCConv(BaseModule): + """SCConv (Self-calibrated Convolution) + + Args: + in_channels (int): The input channels of the SCConv. + out_channels (int): The output channel of the SCConv. + stride (int): stride of SCConv. + pooling_r (int): size of pooling for scconv. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + stride, + pooling_r, + conv_cfg=None, + norm_cfg=dict(type='BN', momentum=0.1), + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + + assert in_channels == out_channels + + self.k2 = nn.Sequential( + nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r), + build_conv_layer( + conv_cfg, + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(norm_cfg, in_channels)[1], + ) + self.k3 = nn.Sequential( + build_conv_layer( + conv_cfg, + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(norm_cfg, in_channels)[1], + ) + self.k4 = nn.Sequential( + build_conv_layer( + conv_cfg, + in_channels, + in_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1], + nn.ReLU(inplace=True), + ) + + def forward(self, x): + """Forward function.""" + identity = x + + out = torch.sigmoid( + torch.add(identity, F.interpolate(self.k2(x), + identity.size()[2:]))) + out = torch.mul(self.k3(x), out) + out = self.k4(out) + + return out + + +class SCBottleneck(Bottleneck): + """SC(Self-calibrated) Bottleneck. + + Args: + in_channels (int): The input channels of the SCBottleneck block. + out_channels (int): The output channel of the SCBottleneck block. + """ + + pooling_r = 4 + + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.mid_channels = out_channels // self.expansion // 2 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=1, + bias=False) + self.add_module(self.norm1_name, norm1) + + self.k1 = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.stride, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, self.mid_channels)[1], + nn.ReLU(inplace=True)) + + self.conv2 = build_conv_layer( + self.conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=1, + bias=False) + self.add_module(self.norm2_name, norm2) + + self.scconv = SCConv(self.mid_channels, self.mid_channels, self.stride, + self.pooling_r, self.conv_cfg, self.norm_cfg) + + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels * 2, + out_channels, + kernel_size=1, + stride=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out_a = self.conv1(x) + out_a = self.norm1(out_a) + out_a = self.relu(out_a) + + out_a = self.k1(out_a) + + out_b = self.conv2(x) + out_b = self.norm2(out_b) + out_b = self.relu(out_b) + + out_b = self.scconv(out_b) + + out = self.conv3(torch.cat([out_a, out_b], dim=1)) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class SCNet(ResNet): + """SCNet backbone. + + Improving Convolutional Networks with Self-Calibrated Convolutions, + Jiang-Jiang Liu, Qibin Hou, Ming-Ming Cheng, Changhu Wang, Jiashi Feng, + IEEE CVPR, 2020. + http://mftp.mmcheng.net/Papers/20cvprSCNet.pdf + + Args: + depth (int): Depth of scnet, from {50, 101}. + in_channels (int): Number of input image channels. Normally 3. + base_channels (int): Number of base channels of hidden layer. + num_stages (int): SCNet stages, normally 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from mmpose.models import SCNet + >>> import torch + >>> self = SCNet(depth=50, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + + arch_settings = { + 50: (SCBottleneck, [3, 4, 6, 3]), + 101: (SCBottleneck, [3, 4, 23, 3]) + } + + def __init__(self, depth, **kwargs): + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for SCNet') + super().__init__(depth, **kwargs) diff --git a/mmpose/models/backbones/seresnet.py b/mmpose/models/backbones/seresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..617a1b72bee737ef0f3fb305e83ce33d8c8a7ea1 --- /dev/null +++ b/mmpose/models/backbones/seresnet.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.utils.checkpoint as cp + +from mmpose.registry import MODELS +from .resnet import Bottleneck, ResLayer, ResNet +from .utils.se_layer import SELayer + + +class SEBottleneck(Bottleneck): + """SEBottleneck block for SEResNet. + + Args: + in_channels (int): The input channels of the SEBottleneck block. + out_channels (int): The output channel of the SEBottleneck block. + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + """ + + def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.se_layer = SELayer(out_channels, ratio=se_ratio) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + out = self.se_layer(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class SEResNet(ResNet): + """SEResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import SEResNet + >>> import torch + >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, se_ratio=16, **kwargs): + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for SEResNet') + self.se_ratio = se_ratio + super().__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer(se_ratio=self.se_ratio, **kwargs) diff --git a/mmpose/models/backbones/seresnext.py b/mmpose/models/backbones/seresnext.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f5a6c8f3fe6b602aceb331781cd119958518b7 --- /dev/null +++ b/mmpose/models/backbones/seresnext.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpose.registry import MODELS +from .resnet import ResLayer +from .seresnet import SEBottleneck as _SEBottleneck +from .seresnet import SEResNet + + +class SEBottleneck(_SEBottleneck): + """SEBottleneck block for SEResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + base_channels (int): Middle channels of the first stage. Default: 64. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + se_ratio=16, + **kwargs): + super().__init__(in_channels, out_channels, se_ratio, **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # We follow the same rational of ResNext to compute mid_channels. + # For SEResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for SEResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class SEResNeXt(SEResNet): + """SEResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import SEResNeXt + >>> import torch + >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super().__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpose/models/backbones/shufflenet_v1.py b/mmpose/models/backbones/shufflenet_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..17491910e9c1c2ec4eea04ca715dc91293f00cd4 --- /dev/null +++ b/mmpose/models/backbones/shufflenet_v1.py @@ -0,0 +1,338 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_activation_layer +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone +from .utils import channel_shuffle, make_divisible + + +class ShuffleUnit(BaseModule): + """ShuffleUnit block. + + ShuffleNet unit with pointwise group convolution (GConv) and channel + shuffle. + + Args: + in_channels (int): The input channels of the ShuffleUnit. + out_channels (int): The output channels of the ShuffleUnit. + groups (int, optional): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3 + first_block (bool, optional): Whether it is the first ShuffleUnit of a + sequential ShuffleUnits. Default: True, which means not using the + grouped 1x1 convolution. + combine (str, optional): The ways to combine the input and output + branches. Default: 'add'. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + groups=3, + first_block=True, + combine='add', + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.first_block = first_block + self.combine = combine + self.groups = groups + self.bottleneck_channels = self.out_channels // 4 + self.with_cp = with_cp + + if self.combine == 'add': + self.depthwise_stride = 1 + self._combine_func = self._add + assert in_channels == out_channels, ( + 'in_channels must be equal to out_channels when combine ' + 'is add') + elif self.combine == 'concat': + self.depthwise_stride = 2 + self._combine_func = self._concat + self.out_channels -= self.in_channels + self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + raise ValueError(f'Cannot combine tensors with {self.combine}. ' + 'Only "add" and "concat" are supported') + + self.first_1x1_groups = 1 if first_block else self.groups + self.g_conv_1x1_compress = ConvModule( + in_channels=self.in_channels, + out_channels=self.bottleneck_channels, + kernel_size=1, + groups=self.first_1x1_groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.depthwise_conv3x3_bn = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.bottleneck_channels, + kernel_size=3, + stride=self.depthwise_stride, + padding=1, + groups=self.bottleneck_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.g_conv_1x1_expand = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.out_channels, + kernel_size=1, + groups=self.groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.act = build_activation_layer(act_cfg) + + @staticmethod + def _add(x, out): + # residual connection + return x + out + + @staticmethod + def _concat(x, out): + # concatenate along channel axis + return torch.cat((x, out), 1) + + def forward(self, x): + + def _inner_forward(x): + residual = x + + out = self.g_conv_1x1_compress(x) + out = self.depthwise_conv3x3_bn(out) + + if self.groups > 1: + out = channel_shuffle(out, self.groups) + + out = self.g_conv_1x1_expand(out) + + if self.combine == 'concat': + residual = self.avgpool(residual) + out = self.act(out) + out = self._combine_func(residual, out) + else: + out = self._combine_func(residual, out) + out = self.act(out) + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class ShuffleNetV1(BaseBackbone): + """ShuffleNetV1 backbone. + + Args: + groups (int, optional): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3. + widen_factor (float, optional): Width multiplier - adjusts the number + of channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (2, ) + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.01, layer=['Conv2d']), + dict( + type='Constant', + val=1, + bias=0.0001 + layer=['_BatchNorm', 'GroupNorm']) + ]`` + """ + + def __init__(self, + groups=3, + widen_factor=1.0, + out_indices=(2, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Normal', std=0.01, layer=['Conv2d']), + dict( + type='Constant', + val=1, + bias=0.0001, + layer=['_BatchNorm', 'GroupNorm']) + ]): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__(init_cfg=init_cfg) + self.stage_blocks = [4, 8, 4] + self.groups = groups + + for index in out_indices: + if index not in range(0, 3): + raise ValueError('the item in out_indices must in ' + f'range(0, 3). But received {index}') + + if frozen_stages not in range(-1, 3): + raise ValueError('frozen_stages must be in range(-1, 3). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if groups == 1: + channels = (144, 288, 576) + elif groups == 2: + channels = (200, 400, 800) + elif groups == 3: + channels = (240, 480, 960) + elif groups == 4: + channels = (272, 544, 1088) + elif groups == 8: + channels = (384, 768, 1536) + else: + raise ValueError(f'{groups} groups is not supported for 1x1 ' + 'Grouped Convolutions') + + channels = [make_divisible(ch * widen_factor, 8) for ch in channels] + + self.in_channels = int(24 * widen_factor) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + first_block = (i == 0) + layer = self.make_layer(channels[i], num_blocks, first_block) + self.layers.append(layer) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + layer = self.layers[i] + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + super(ShuffleNetV1, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + return + + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d) and 'conv1' not in name: + nn.init.normal_(m.weight, mean=0, std=1.0 / m.weight.shape[1]) + + def make_layer(self, out_channels, num_blocks, first_block=False): + """Stack ShuffleUnit blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): Number of blocks. + first_block (bool, optional): Whether is the first ShuffleUnit of a + sequential ShuffleUnits. Default: False, which means using + the grouped 1x1 convolution. + """ + layers = [] + for i in range(num_blocks): + first_block = first_block if i == 0 else False + combine_mode = 'concat' if i == 0 else 'add' + layers.append( + ShuffleUnit( + self.in_channels, + out_channels, + groups=self.groups, + first_block=first_block, + combine=combine_mode, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/backbones/shufflenet_v2.py b/mmpose/models/backbones/shufflenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9757841e73bf547fde77cf847a917c46acfb0b00 --- /dev/null +++ b/mmpose/models/backbones/shufflenet_v2.py @@ -0,0 +1,311 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone +from .utils import channel_shuffle + + +class InvertedResidual(BaseModule): + """InvertedResidual block for ShuffleNetV2 backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 convolution layer. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__(init_cfg=init_cfg) + self.stride = stride + self.with_cp = with_cp + + branch_features = out_channels // 2 + if self.stride == 1: + assert in_channels == branch_features * 2, ( + f'in_channels ({in_channels}) should equal to ' + f'branch_features * 2 ({branch_features * 2}) ' + 'when stride is 1') + + if in_channels != branch_features * 2: + assert self.stride != 1, ( + f'stride ({self.stride}) should not equal 1 when ' + f'in_channels != branch_features * 2') + + if self.stride > 1: + self.branch1 = nn.Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=self.stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + + self.branch2 = nn.Sequential( + ConvModule( + in_channels if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + groups=branch_features, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + + def _inner_forward(x): + if self.stride > 1: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + else: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class ShuffleNetV2(BaseBackbone): + """ShuffleNetV2 backbone. + + Args: + widen_factor (float): Width multiplier - adjusts the number of + channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.01, layer=['Conv2d']), + dict( + type='Constant', + val=1, + bias=0.0001 + layer=['_BatchNorm', 'GroupNorm']) + ]`` + """ + + def __init__(self, + widen_factor=1.0, + out_indices=(3, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Normal', std=0.01, layer=['Conv2d']), + dict( + type='Constant', + val=1, + bias=0.0001, + layer=['_BatchNorm', 'GroupNorm']) + ]): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__(init_cfg=init_cfg) + self.stage_blocks = [4, 8, 4] + for index in out_indices: + if index not in range(0, 4): + raise ValueError('the item in out_indices must in ' + f'range(0, 4). But received {index}') + + if frozen_stages not in range(-1, 4): + raise ValueError('frozen_stages must be in range(-1, 4). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if widen_factor == 0.5: + channels = [48, 96, 192, 1024] + elif widen_factor == 1.0: + channels = [116, 232, 464, 1024] + elif widen_factor == 1.5: + channels = [176, 352, 704, 1024] + elif widen_factor == 2.0: + channels = [244, 488, 976, 2048] + else: + raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. ' + f'But received {widen_factor}') + + self.in_channels = 24 + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + layer = self._make_layer(channels[i], num_blocks) + self.layers.append(layer) + + output_channels = channels[-1] + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=output_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def _make_layer(self, out_channels, num_blocks): + """Stack blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): number of blocks. + """ + layers = [] + for i in range(num_blocks): + stride = 2 if i == 0 else 1 + layers.append( + InvertedResidual( + in_channels=self.in_channels, + out_channels=out_channels, + stride=stride, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ShuffleNetV2, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + return + + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d) and 'conv1' not in name: + nn.init.normal_(m.weight, mean=0, std=1.0 / m.weight.shape[1]) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpose/models/backbones/swin.py b/mmpose/models/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f7c972787c19f64eb398615966722c5bdcd533 --- /dev/null +++ b/mmpose/models/backbones/swin.py @@ -0,0 +1,739 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, build_dropout +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ +from mmengine.runner import load_state_dict +from mmengine.utils import to_2tuple + +from mmpose.registry import MODELS +from mmpose.utils import get_root_logger +from ..utils.transformer import PatchEmbed, PatchMerging +from .base_backbone import BaseBackbone +from .utils import get_state_dict +from .utils.ckpt_convert import swin_converter + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ShiftWindowMSA(BaseModule): + """Shifted Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Defaults: 0. + proj_drop_rate (float, optional): Dropout ratio of output. + Defaults: 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults: dict(type='DropPath', drop_prob=0.). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0, + proj_drop_rate=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.window_size = window_size + self.shift_size = shift_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(window_size), + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate) + + self.drop = build_dropout(dropout_layer) + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, 'input feature has wrong size' + query = query.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll( + query, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + shifted_query = query + attn_mask = None + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + return x + + def window_reverse(self, windows, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = self.window_size + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + def window_partition(self, x): + """ + Args: + x: (B, H, W, C) + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + window_size = self.window_size + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +class SwinBlock(BaseModule): + """" + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + window_size (int, optional): The local window scale. Default: 7. + shift (bool, optional): whether to shift window or not. Default False. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + window_size=7, + shift=False, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(SwinBlock, self).__init__(init_cfg=init_cfg) + + self.with_cp = with_cp + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 if shift else 0, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=2, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=True, + init_cfg=None) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockSequence(BaseModule): + """Implements one stage in Swin Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + depth (int): The number of blocks in this stage. + window_size (int, optional): The local window scale. Default: 7. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float | list[float], optional): Stochastic depth + rate. Default: 0. + downsample (nn.Module | None, optional): The downsample operation + module. Default: None. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + depth, + window_size=7, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + downsample=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(drop_path_rate, list): + drop_path_rates = drop_path_rate + assert len(drop_path_rates) == depth + else: + drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] + + self.blocks = nn.ModuleList() + for i in range(depth): + block = SwinBlock( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + window_size=window_size, + shift=False if i % 2 == 0 else True, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rates[i], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp) + self.blocks.append(block) + + self.downsample = downsample + + def forward(self, x, hw_shape): + for block in self.blocks: + x = block(x, hw_shape) + + if self.downsample: + x_down, down_hw_shape = self.downsample(x, hw_shape) + return x_down, down_hw_shape, x, hw_shape + else: + return x, hw_shape, x, hw_shape + + +@MODELS.register_module() +class SwinTransformer(BaseBackbone): + """ Swin Transformer + A PyTorch implement of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/abs/2103.14030 + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): The num of input channels. + Defaults: 3. + embed_dims (int): The feature dimension. Default: 96. + patch_size (int | tuple[int]): Patch size. Default: 4. + window_size (int): Window size. Default: 7. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + depths (tuple[int]): Depths of each Swin Transformer stage. + Default: (2, 2, 6, 2). + num_heads (tuple[int]): Parallel attention heads of each Swin + Transformer stage. Default: (3, 6, 12, 24). + strides (tuple[int]): The patch merging or patch embedding stride of + each Swin Transformer stage. (In swin, we set kernel size equal to + stride.) Default: (4, 2, 2, 2). + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + patch_norm (bool): If add a norm layer for patch embed and patch + merging. Default: True. + drop_rate (float): Dropout rate. Defaults: 0. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: False. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LN'). + norm_cfg (dict): Config dict for normalization layer at + output of backone. Defaults: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None. + convert_weights (bool): The flag indicates whether the + pre-trained model is from the original repo. We may need + to convert some keys to make it compatible. + Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + Default: -1 (-1 means not freezing any parameters). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: ``[ + dict(type='TruncNormal', std=.02, layer=['Linear']), + dict(type='Constant', val=1, layer=['LayerNorm']), + ]`` + """ + + def __init__(self, + pretrain_img_size=224, + in_channels=3, + embed_dims=96, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + use_abs_pos_embed=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + convert_weights=False, + frozen_stages=-1, + init_cfg=[ + dict(type='TruncNormal', std=.02, layer=['Linear']), + dict(type='Constant', val=1, layer=['LayerNorm']), + ]): + self.convert_weights = convert_weights + self.frozen_stages = frozen_stages + if isinstance(pretrain_img_size, int): + pretrain_img_size = to_2tuple(pretrain_img_size) + elif isinstance(pretrain_img_size, tuple): + if len(pretrain_img_size) == 1: + pretrain_img_size = to_2tuple(pretrain_img_size[0]) + assert len(pretrain_img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pretrain_img_size)}' + + super(SwinTransformer, self).__init__(init_cfg=init_cfg) + + num_layers = len(depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + + assert strides[0] == patch_size, 'Use non-overlapping patch embed.' + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=strides[0], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + + if self.use_abs_pos_embed: + patch_row = pretrain_img_size[0] // patch_size + patch_col = pretrain_img_size[1] // patch_size + num_patches = patch_row * patch_col + self.absolute_pos_embed = nn.Parameter( + torch.zeros((1, num_patches, embed_dims))) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # set stochastic depth decay rule + total_depth = sum(depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + + self.stages = nn.ModuleList() + in_channels = embed_dims + for i in range(num_layers): + if i < num_layers - 1: + downsample = PatchMerging( + in_channels=in_channels, + out_channels=2 * in_channels, + stride=strides[i + 1], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + else: + downsample = None + + stage = SwinBlockSequence( + embed_dims=in_channels, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * in_channels, + depth=depths[i], + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])], + downsample=downsample, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp) + self.stages.append(stage) + if downsample: + in_channels = downsample.out_channels + + self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] + # Add a norm layer for each output + for i in out_indices: + layer = build_norm_layer(norm_cfg, self.num_features[i])[1] + layer_name = f'norm{i}' + self.add_module(layer_name, layer) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + if self.use_abs_pos_embed: + self.absolute_pos_embed.requires_grad = False + self.drop_after_pos.eval() + + for i in range(1, self.frozen_stages + 1): + + if (i - 1) in self.out_indices: + norm_layer = getattr(self, f'norm{i-1}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + m = self.stages[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress zero_init_residual if use pretrained model. + logger = get_root_logger() + state_dict = get_state_dict( + self.init_cfg['checkpoint'], map_location='cpu') + if self.convert_weights: + # supported loading weight from original repo + state_dict = swin_converter(state_dict) + + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = self.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2).contiguous() + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() + if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = self.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f'Error in loading {table_key}, pass') + elif L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0).contiguous() + + # load state_dict + load_state_dict(self, state_dict, strict=False, logger=logger) + + else: + super(SwinTransformer, self).init_weights() + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape, out, out_hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(out) + out = out.view(-1, *out_hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) diff --git a/mmpose/models/backbones/tcn.py b/mmpose/models/backbones/tcn.py new file mode 100644 index 0000000000000000000000000000000000000000..ef49a1ff075288cc7a23f51f47c5b1bcdd383894 --- /dev/null +++ b/mmpose/models/backbones/tcn.py @@ -0,0 +1,284 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +from mmcv.cnn import ConvModule, build_conv_layer +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from ..utils.regularizations import WeightNormClipHook +from .base_backbone import BaseBackbone + + +class BasicTemporalBlock(BaseModule): + """Basic block for VideoPose3D. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + mid_channels (int): The output channels of conv1. Default: 1024. + kernel_size (int): Size of the convolving kernel. Default: 3. + dilation (int): Spacing between kernel elements. Default: 3. + dropout (float): Dropout rate. Default: 0.25. + causal (bool): Use causal convolutions instead of symmetric + convolutions (for real-time applications). Default: False. + residual (bool): Use residual connection. Default: True. + use_stride_conv (bool): Use optimized TCN that designed + specifically for single-frame batching, i.e. where batches have + input length = receptive field, and output length = 1. This + implementation replaces dilated convolutions with strided + convolutions to avoid generating unused intermediate results. + Default: False. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: dict(type='Conv1d'). + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN1d'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels=1024, + kernel_size=3, + dilation=3, + dropout=0.25, + causal=False, + residual=True, + use_stride_conv=False, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + init_cfg=None): + # Protect mutable default arguments + conv_cfg = copy.deepcopy(conv_cfg) + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = mid_channels + self.kernel_size = kernel_size + self.dilation = dilation + self.dropout = dropout + self.causal = causal + self.residual = residual + self.use_stride_conv = use_stride_conv + + self.pad = (kernel_size - 1) * dilation // 2 + if use_stride_conv: + self.stride = kernel_size + self.causal_shift = kernel_size // 2 if causal else 0 + self.dilation = 1 + else: + self.stride = 1 + self.causal_shift = kernel_size // 2 * dilation if causal else 0 + + self.conv1 = nn.Sequential( + ConvModule( + in_channels, + mid_channels, + kernel_size=kernel_size, + stride=self.stride, + dilation=self.dilation, + bias='auto', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + self.conv2 = nn.Sequential( + ConvModule( + mid_channels, + out_channels, + kernel_size=1, + bias='auto', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + + if residual and in_channels != out_channels: + self.short_cut = build_conv_layer(conv_cfg, in_channels, + out_channels, 1) + else: + self.short_cut = None + + self.dropout = nn.Dropout(dropout) if dropout > 0 else None + + def forward(self, x): + """Forward function.""" + if self.use_stride_conv: + assert self.causal_shift + self.kernel_size // 2 < x.shape[2] + else: + assert 0 <= self.pad + self.causal_shift < x.shape[2] - \ + self.pad + self.causal_shift <= x.shape[2] + + out = self.conv1(x) + if self.dropout is not None: + out = self.dropout(out) + + out = self.conv2(out) + if self.dropout is not None: + out = self.dropout(out) + + if self.residual: + if self.use_stride_conv: + res = x[:, :, self.causal_shift + + self.kernel_size // 2::self.kernel_size] + else: + res = x[:, :, + (self.pad + self.causal_shift):(x.shape[2] - self.pad + + self.causal_shift)] + + if self.short_cut is not None: + res = self.short_cut(res) + out = out + res + + return out + + +@MODELS.register_module() +class TCN(BaseBackbone): + """TCN backbone. + + Temporal Convolutional Networks. + More details can be found in the + `paper `__ . + + Args: + in_channels (int): Number of input channels, which equals to + num_keypoints * num_features. + stem_channels (int): Number of feature channels. Default: 1024. + num_blocks (int): NUmber of basic temporal convolutional blocks. + Default: 2. + kernel_sizes (Sequence[int]): Sizes of the convolving kernel of + each basic block. Default: ``(3, 3, 3)``. + dropout (float): Dropout rate. Default: 0.25. + causal (bool): Use causal convolutions instead of symmetric + convolutions (for real-time applications). + Default: False. + residual (bool): Use residual connection. Default: True. + use_stride_conv (bool): Use TCN backbone optimized for + single-frame batching, i.e. where batches have input length = + receptive field, and output length = 1. This implementation + replaces dilated convolutions with strided convolutions to avoid + generating unused intermediate results. The weights are + interchangeable with the reference implementation. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: dict(type='Conv1d'). + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN1d'). + max_norm (float|None): if not None, the weight of convolution layers + will be clipped to have a maximum norm of max_norm. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict( + type='Kaiming', + mode='fan_in', + nonlinearity='relu', + layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + + Example: + >>> from mmpose.models import TCN + >>> import torch + >>> self = TCN(in_channels=34) + >>> self.eval() + >>> inputs = torch.rand(1, 34, 243) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 1024, 235) + (1, 1024, 217) + """ + + def __init__(self, + in_channels, + stem_channels=1024, + num_blocks=2, + kernel_sizes=(3, 3, 3), + dropout=0.25, + causal=False, + residual=True, + use_stride_conv=False, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + max_norm=None, + init_cfg=[ + dict( + type='Kaiming', + mode='fan_in', + nonlinearity='relu', + layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + # Protect mutable default arguments + conv_cfg = copy.deepcopy(conv_cfg) + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__() + self.in_channels = in_channels + self.stem_channels = stem_channels + self.num_blocks = num_blocks + self.kernel_sizes = kernel_sizes + self.dropout = dropout + self.causal = causal + self.residual = residual + self.use_stride_conv = use_stride_conv + self.max_norm = max_norm + + assert num_blocks == len(kernel_sizes) - 1 + for ks in kernel_sizes: + assert ks % 2 == 1, 'Only odd filter widths are supported.' + + self.expand_conv = ConvModule( + in_channels, + stem_channels, + kernel_size=kernel_sizes[0], + stride=kernel_sizes[0] if use_stride_conv else 1, + bias='auto', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + dilation = kernel_sizes[0] + self.tcn_blocks = nn.ModuleList() + for i in range(1, num_blocks + 1): + self.tcn_blocks.append( + BasicTemporalBlock( + in_channels=stem_channels, + out_channels=stem_channels, + mid_channels=stem_channels, + kernel_size=kernel_sizes[i], + dilation=dilation, + dropout=dropout, + causal=causal, + residual=residual, + use_stride_conv=use_stride_conv, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + dilation *= kernel_sizes[i] + + if self.max_norm is not None: + # Apply weight norm clip to conv layers + weight_clip = WeightNormClipHook(self.max_norm) + for module in self.modules(): + if isinstance(module, nn.modules.conv._ConvNd): + weight_clip.register(module) + + self.dropout = nn.Dropout(dropout) if dropout > 0 else None + + def forward(self, x): + """Forward function.""" + x = self.expand_conv(x) + + if self.dropout is not None: + x = self.dropout(x) + + outs = [] + for i in range(self.num_blocks): + x = self.tcn_blocks[i](x) + outs.append(x) + + return tuple(outs) diff --git a/mmpose/models/backbones/utils/__init__.py b/mmpose/models/backbones/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07e42f89126c9e5663123794f92987b4f9b347f1 --- /dev/null +++ b/mmpose/models/backbones/utils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .channel_shuffle import channel_shuffle +from .inverted_residual import InvertedResidual +from .make_divisible import make_divisible +from .se_layer import SELayer +from .utils import get_state_dict, load_checkpoint + +__all__ = [ + 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer', + 'load_checkpoint', 'get_state_dict' +] diff --git a/mmpose/models/backbones/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18849dde1a38fd3b8acbd319875d64b9a25873ba Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/backbones/utils/__pycache__/channel_shuffle.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/channel_shuffle.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c81c5afae2980d7b9186af1c1a238b6f0dc87e7b Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/channel_shuffle.cpython-38.pyc differ diff --git a/mmpose/models/backbones/utils/__pycache__/ckpt_convert.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/ckpt_convert.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eabf9f401f7d7cc416f98fe636cb61d68c6f373 Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/ckpt_convert.cpython-38.pyc differ diff --git a/mmpose/models/backbones/utils/__pycache__/inverted_residual.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/inverted_residual.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c0591bcd4449bc76299c39a324cbd3ae9ec9c40 Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/inverted_residual.cpython-38.pyc differ diff --git a/mmpose/models/backbones/utils/__pycache__/make_divisible.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/make_divisible.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..881b3b9f6688d5a9442faeb8df836a43cb2d2a0d Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/make_divisible.cpython-38.pyc differ diff --git a/mmpose/models/backbones/utils/__pycache__/se_layer.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/se_layer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba43107fa90ef4e0822e649087196d6222e5da31 Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/se_layer.cpython-38.pyc differ diff --git a/mmpose/models/backbones/utils/__pycache__/utils.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38312ac3b4e43d1c1271ec9039a0e384347f30a6 Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpose/models/backbones/utils/channel_shuffle.py b/mmpose/models/backbones/utils/channel_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..aedd826bee690d42d92ed8a7f538b221e5b069e2 --- /dev/null +++ b/mmpose/models/backbones/utils/channel_shuffle.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def channel_shuffle(x, groups): + """Channel Shuffle operation. + + This function enables cross-group information flow for multiple groups + convolution layers. + + Args: + x (Tensor): The input tensor. + groups (int): The number of groups to divide the input tensor + in the channel dimension. + + Returns: + Tensor: The output tensor after channel shuffle operation. + """ + + batch_size, num_channels, height, width = x.size() + assert (num_channels % groups == 0), ('num_channels should be ' + 'divisible by groups') + channels_per_group = num_channels // groups + + x = x.view(batch_size, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + x = x.view(batch_size, groups * channels_per_group, height, width) + + return x diff --git a/mmpose/models/backbones/utils/ckpt_convert.py b/mmpose/models/backbones/utils/ckpt_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..14a43892c6630be31e915ed1f8b9164ba250e8bd --- /dev/null +++ b/mmpose/models/backbones/utils/ckpt_convert.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# This script consists of several convert functions which +# can modify the weights of model in original repo to be +# pre-trained weights. + +from collections import OrderedDict + + +def swin_converter(ckpt): + + new_ckpt = OrderedDict() + + def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, + 2).reshape(out_channel, in_channel) + return x + + def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + for k, v in ckpt.items(): + if k.startswith('head'): + continue + elif k.startswith('layers'): + new_v = v + if 'attn.' in k: + new_k = k.replace('attn.', 'attn.w_msa.') + elif 'mlp.' in k: + if 'mlp.fc1.' in k: + new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') + elif 'mlp.fc2.' in k: + new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + else: + new_k = k.replace('mlp.', 'ffn.') + elif 'downsample' in k: + new_k = k + if 'reduction.' in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace('layers', 'stages', 1) + elif k.startswith('patch_embed'): + new_v = v + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + else: + new_v = v + new_k = k + + new_ckpt['backbone.' + new_k] = new_v + + return new_ckpt diff --git a/mmpose/models/backbones/utils/inverted_residual.py b/mmpose/models/backbones/utils/inverted_residual.py new file mode 100644 index 0000000000000000000000000000000000000000..dff762c570550e4a738ae1833a4c82c18777115d --- /dev/null +++ b/mmpose/models/backbones/utils/inverted_residual.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule + +from .se_layer import SELayer + + +class InvertedResidual(nn.Module): + """Inverted Residual Block. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + groups (None or int): The group number of the depthwise convolution. + Default: None, which means group number = mid_channels. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. + Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + groups=None, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + act_cfg = copy.deepcopy(act_cfg) + super().__init__() + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if groups is None: + groups = mid_channels + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if self.with_se: + self.se = SELayer(**se_cfg) + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + out + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/mmpose/models/backbones/utils/make_divisible.py b/mmpose/models/backbones/utils/make_divisible.py new file mode 100644 index 0000000000000000000000000000000000000000..b7666be65939d5c76057e73927c230029cb1871d --- /dev/null +++ b/mmpose/models/backbones/utils/make_divisible.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number down to the nearest value that can + be divisible by the divisor. + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int, optional): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float, optional): The minimum ratio of the rounded channel + number to the original channel number. Default: 0.9. + Returns: + int: The modified output channel number + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/mmpose/models/backbones/utils/se_layer.py b/mmpose/models/backbones/utils/se_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..ec6d7aeaa9a990dbaf437b4ff4f4ba685e008245 --- /dev/null +++ b/mmpose/models/backbones/utils/se_layer.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmengine +import torch.nn as nn +from mmcv.cnn import ConvModule + + +class SELayer(nn.Module): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will be + ``int(channels/ratio)``. Default: 16. + conv_cfg (None or dict): Config dict for convolution layer. + Default: None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Default: (dict(type='ReLU'), dict(type='Sigmoid')) + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))): + super().__init__() + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert mmengine.is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=int(channels / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(channels / ratio), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out diff --git a/mmpose/models/backbones/utils/utils.py b/mmpose/models/backbones/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc4fe40cd481391edf73872e2d4f6eb35592779 --- /dev/null +++ b/mmpose/models/backbones/utils/utils.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +from mmengine.runner import CheckpointLoader, load_state_dict + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = CheckpointLoader.load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict_tmp = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict_tmp = checkpoint['model'] + else: + state_dict_tmp = checkpoint + + state_dict = OrderedDict() + # strip prefix of state_dict + for k, v in state_dict_tmp.items(): + if k.startswith('module.backbone.'): + state_dict[k[16:]] = v + elif k.startswith('module.'): + state_dict[k[7:]] = v + elif k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def get_state_dict(filename, map_location='cpu'): + """Get state_dict from a file or URI. + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. + map_location (str): Same as :func:`torch.load`. + + Returns: + OrderedDict: The state_dict. + """ + checkpoint = CheckpointLoader.load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict_tmp = checkpoint['state_dict'] + else: + state_dict_tmp = checkpoint + + state_dict = OrderedDict() + # strip prefix of state_dict + for k, v in state_dict_tmp.items(): + if k.startswith('module.backbone.'): + state_dict[k[16:]] = v + elif k.startswith('module.'): + state_dict[k[7:]] = v + elif k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + + return state_dict diff --git a/mmpose/models/backbones/v2v_net.py b/mmpose/models/backbones/v2v_net.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd1ab93b105b345aabc0ace2c7e776cd99e36a9 --- /dev/null +++ b/mmpose/models/backbones/v2v_net.py @@ -0,0 +1,275 @@ +# ------------------------------------------------------------------------------ +# Copyright and License Information +# Adapted from +# https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models/v2v_net.py +# Original Licence: MIT License +# ------------------------------------------------------------------------------ + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +class Basic3DBlock(BaseModule): + """A basic 3D convolutional block. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + kernel_size (int): Kernel size of the convolution operation + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: dict(type='Conv3d') + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN3d') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + conv_cfg=dict(type='Conv3d'), + norm_cfg=dict(type='BN3d'), + init_cfg=None): + super(Basic3DBlock, self).__init__(init_cfg=init_cfg) + self.block = ConvModule( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=((kernel_size - 1) // 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True) + + def forward(self, x): + """Forward function.""" + return self.block(x) + + +class Res3DBlock(BaseModule): + """A residual 3D convolutional block. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + kernel_size (int): Kernel size of the convolution operation + Default: 3 + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: dict(type='Conv3d') + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN3d') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + conv_cfg=dict(type='Conv3d'), + norm_cfg=dict(type='BN3d'), + init_cfg=None): + super(Res3DBlock, self).__init__(init_cfg=init_cfg) + self.res_branch = nn.Sequential( + ConvModule( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=((kernel_size - 1) // 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True), + ConvModule( + out_channels, + out_channels, + kernel_size, + stride=1, + padding=((kernel_size - 1) // 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + bias=True)) + + if in_channels == out_channels: + self.skip_con = nn.Sequential() + else: + self.skip_con = ConvModule( + in_channels, + out_channels, + 1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + bias=True) + + def forward(self, x): + """Forward function.""" + res = self.res_branch(x) + skip = self.skip_con(x) + return F.relu(res + skip, True) + + +class Pool3DBlock(BaseModule): + """A 3D max-pool block. + + Args: + pool_size (int): Pool size of the 3D max-pool layer + """ + + def __init__(self, pool_size): + super(Pool3DBlock, self).__init__() + self.pool_size = pool_size + + def forward(self, x): + """Forward function.""" + return F.max_pool3d( + x, kernel_size=self.pool_size, stride=self.pool_size) + + +class Upsample3DBlock(BaseModule): + """A 3D upsample block. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + kernel_size (int): Kernel size of the transposed convolution operation. + Default: 2 + stride (int): Kernel size of the transposed convolution operation. + Default: 2 + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=2, + init_cfg=None): + super(Upsample3DBlock, self).__init__(init_cfg=init_cfg) + assert kernel_size == 2 + assert stride == 2 + self.block = nn.Sequential( + nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + output_padding=0), nn.BatchNorm3d(out_channels), nn.ReLU(True)) + + def forward(self, x): + """Forward function.""" + return self.block(x) + + +class EncoderDecorder(BaseModule): + """An encoder-decoder block. + + Args: + in_channels (int): Input channels of this block + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, in_channels=32, init_cfg=None): + super(EncoderDecorder, self).__init__(init_cfg=init_cfg) + + self.encoder_pool1 = Pool3DBlock(2) + self.encoder_res1 = Res3DBlock(in_channels, in_channels * 2) + self.encoder_pool2 = Pool3DBlock(2) + self.encoder_res2 = Res3DBlock(in_channels * 2, in_channels * 4) + + self.mid_res = Res3DBlock(in_channels * 4, in_channels * 4) + + self.decoder_res2 = Res3DBlock(in_channels * 4, in_channels * 4) + self.decoder_upsample2 = Upsample3DBlock(in_channels * 4, + in_channels * 2, 2, 2) + self.decoder_res1 = Res3DBlock(in_channels * 2, in_channels * 2) + self.decoder_upsample1 = Upsample3DBlock(in_channels * 2, in_channels, + 2, 2) + + self.skip_res1 = Res3DBlock(in_channels, in_channels) + self.skip_res2 = Res3DBlock(in_channels * 2, in_channels * 2) + + def forward(self, x): + """Forward function.""" + skip_x1 = self.skip_res1(x) + x = self.encoder_pool1(x) + x = self.encoder_res1(x) + + skip_x2 = self.skip_res2(x) + x = self.encoder_pool2(x) + x = self.encoder_res2(x) + + x = self.mid_res(x) + + x = self.decoder_res2(x) + x = self.decoder_upsample2(x) + x = x + skip_x2 + + x = self.decoder_res1(x) + x = self.decoder_upsample1(x) + x = x + skip_x1 + + return x + + +@MODELS.register_module() +class V2VNet(BaseBackbone): + """V2VNet. + + Please refer to the `paper ` + for details. + + Args: + input_channels (int): + Number of channels of the input feature volume. + output_channels (int): + Number of channels of the output volume. + mid_channels (int): + Input and output channels of the encoder-decoder block. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: ``dict( + type='Normal', + std=0.001, + layer=['Conv3d', 'ConvTranspose3d'] + )`` + """ + + def __init__(self, + input_channels, + output_channels, + mid_channels=32, + init_cfg=dict( + type='Normal', + std=0.001, + layer=['Conv3d', 'ConvTranspose3d'])): + super(V2VNet, self).__init__(init_cfg=init_cfg) + + self.front_layers = nn.Sequential( + Basic3DBlock(input_channels, mid_channels // 2, 7), + Res3DBlock(mid_channels // 2, mid_channels), + ) + + self.encoder_decoder = EncoderDecorder(in_channels=mid_channels) + + self.output_layer = nn.Conv3d( + mid_channels, output_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + """Forward function.""" + x = self.front_layers(x) + x = self.encoder_decoder(x) + x = self.output_layer(x) + + return (x, ) diff --git a/mmpose/models/backbones/vgg.py b/mmpose/models/backbones/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa09d8dc7ded75678e8e23846474acee763a532 --- /dev/null +++ b/mmpose/models/backbones/vgg.py @@ -0,0 +1,201 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +def make_vgg_layer(in_channels, + out_channels, + num_blocks, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + dilation=1, + with_norm=False, + ceil_mode=False): + layers = [] + for _ in range(num_blocks): + layer = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=dilation, + padding=dilation, + bias=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + layers.append(layer) + in_channels = out_channels + layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) + + return layers + + +@MODELS.register_module() +class VGG(BaseBackbone): + """VGG backbone. + + Args: + depth (int): Depth of vgg, from {11, 13, 16, 19}. + with_norm (bool): Use BatchNorm or not. + num_classes (int): number of classes for classification. + num_stages (int): VGG stages, normally 5. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. When it is None, the default behavior depends on + whether num_classes is specified. If num_classes <= 0, the default + value is (4, ), outputting the last feature map before classifier. + If num_classes > 0, the default value is (5, ), outputting the + classification score. Default: None. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False. + with_last_pool (bool): Whether to keep the last pooling before + classifier. Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict( + type='Normal', + std=0.01, + layer=['Linear']), + ]`` + """ + + # Parameters to build layers. Each element specifies the number of conv in + # each stage. For example, VGG11 contains 11 layers with learnable + # parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3, + # where 3 indicates the last three fully-connected layers. + arch_settings = { + 11: (1, 1, 2, 2, 2), + 13: (2, 2, 2, 2, 2), + 16: (2, 2, 3, 3, 3), + 19: (2, 2, 4, 4, 4) + } + + def __init__(self, + depth, + num_classes=-1, + num_stages=5, + dilations=(1, 1, 1, 1, 1), + out_indices=None, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + norm_eval=False, + ceil_mode=False, + with_last_pool=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict(type='Normal', std=0.01, layer=['Linear']), + ]): + super().__init__(init_cfg=init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for vgg') + assert num_stages >= 1 and num_stages <= 5 + stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + assert len(dilations) == num_stages + + self.num_classes = num_classes + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + with_norm = norm_cfg is not None + + if out_indices is None: + out_indices = (5, ) if num_classes > 0 else (4, ) + assert max(out_indices) <= num_stages + self.out_indices = out_indices + + self.in_channels = 3 + start_idx = 0 + vgg_layers = [] + self.range_sub_modules = [] + for i, num_blocks in enumerate(self.stage_blocks): + num_modules = num_blocks + 1 + end_idx = start_idx + num_modules + dilation = dilations[i] + out_channels = 64 * 2**i if i < 4 else 512 + vgg_layer = make_vgg_layer( + self.in_channels, + out_channels, + num_blocks, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dilation=dilation, + with_norm=with_norm, + ceil_mode=ceil_mode) + vgg_layers.extend(vgg_layer) + self.in_channels = out_channels + self.range_sub_modules.append([start_idx, end_idx]) + start_idx = end_idx + if not with_last_pool: + vgg_layers.pop(-1) + self.range_sub_modules[-1][1] -= 1 + self.module_name = 'features' + self.add_module(self.module_name, nn.Sequential(*vgg_layers)) + + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + outs = [] + vgg_layers = getattr(self, self.module_name) + for i in range(len(self.stage_blocks)): + for j in range(*self.range_sub_modules[i]): + vgg_layer = vgg_layers[j] + x = vgg_layer(x) + if i in self.out_indices: + outs.append(x) + if self.num_classes > 0: + x = x.view(x.size(0), -1) + x = self.classifier(x) + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + vgg_layers = getattr(self, self.module_name) + for i in range(self.frozen_stages): + for j in range(*self.range_sub_modules[i]): + m = vgg_layers[j] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/backbones/vipnas_mbv3.py b/mmpose/models/backbones/vipnas_mbv3.py new file mode 100644 index 0000000000000000000000000000000000000000..9156cafa56d4f15766e48c77cd492e52345aed65 --- /dev/null +++ b/mmpose/models/backbones/vipnas_mbv3.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +from mmcv.cnn import ConvModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone +from .utils import InvertedResidual + + +@MODELS.register_module() +class ViPNAS_MobileNetV3(BaseBackbone): + """ViPNAS_MobileNetV3 backbone. + + "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search" + More details can be found in the `paper + `__ . + + Args: + wid (list(int)): Searched width config for each stage. + expan (list(int)): Searched expansion ratio config for each stage. + dep (list(int)): Searched depth config for each stage. + ks (list(int)): Searched kernel size config for each stage. + group (list(int)): Searched group number config for each stage. + att (list(bool)): Searched attention config for each stage. + stride (list(int)): Stride config for each stage. + act (list(dict)): Activation config for each stage. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + """ + + def __init__( + self, + wid=[16, 16, 24, 40, 80, 112, 160], + expan=[None, 1, 5, 4, 5, 5, 6], + dep=[None, 1, 4, 4, 4, 4, 4], + ks=[3, 3, 7, 7, 5, 7, 5], + group=[None, 8, 120, 20, 100, 280, 240], + att=[None, True, True, False, True, True, True], + stride=[2, 1, 2, 2, 2, 1, 2], + act=['HSwish', 'ReLU', 'ReLU', 'ReLU', 'HSwish', 'HSwish', 'HSwish'], + conv_cfg=None, + norm_cfg=dict(type='BN'), + frozen_stages=-1, + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ], + ): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + self.wid = wid + self.expan = expan + self.dep = dep + self.ks = ks + self.group = group + self.att = att + self.stride = stride + self.act = act + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.wid[0], + kernel_size=self.ks[0], + stride=self.stride[0], + padding=self.ks[0] // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=dict(type=self.act[0])) + + self.layers = self._make_layer() + + def _make_layer(self): + layers = [] + layer_index = 0 + for i, dep in enumerate(self.dep[1:]): + mid_channels = self.wid[i + 1] * self.expan[i + 1] + + if self.att[i + 1]: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=1.0, divisor=2.0))) + else: + se_cfg = None + + if self.expan[i + 1] == 1: + with_expand_conv = False + else: + with_expand_conv = True + + for j in range(dep): + if j == 0: + stride = self.stride[i + 1] + in_channels = self.wid[i] + else: + stride = 1 + in_channels = self.wid[i + 1] + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=self.wid[i + 1], + mid_channels=mid_channels, + kernel_size=self.ks[i + 1], + groups=self.group[i + 1], + stride=stride, + se_cfg=se_cfg, + with_expand_conv=with_expand_conv, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=self.act[i + 1]), + with_cp=self.with_cp) + layer_index += 1 + layer_name = f'layer{layer_index}' + self.add_module(layer_name, layer) + layers.append(layer_name) + return layers + + def forward(self, x): + x = self.conv1(x) + + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + + return (x, ) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/backbones/vipnas_resnet.py b/mmpose/models/backbones/vipnas_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7be810b449c1a840c425c69e3d1d1340583e52ea --- /dev/null +++ b/mmpose/models/backbones/vipnas_resnet.py @@ -0,0 +1,596 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer +from mmcv.cnn.bricks import ContextBlock +from mmengine.model import BaseModule, Sequential +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +class ViPNAS_Bottleneck(BaseModule): + """Bottleneck block for ViPNAS_ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the input/output channels of conv2. Default: 4. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module): downsample operation on identity branch. + Default: None. + style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the + stride-two layer is the 3x3 conv layer, otherwise the stride-two + layer is the first 1x1 conv layer. Default: "pytorch". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + kernel_size (int): kernel size of conv2 searched in ViPANS. + groups (int): group number of conv2 searched in ViPNAS. + attention (bool): whether to use attention module in the end of + the block. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + expansion=4, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + kernel_size=3, + groups=1, + attention=False, + init_cfg=None): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + assert style in ['pytorch', 'caffe'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, out_channels, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=kernel_size, + stride=self.conv2_stride, + padding=kernel_size // 2, + groups=groups, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + if attention: + self.attention = ContextBlock(out_channels, + max(1.0 / 16, 16.0 / out_channels)) + else: + self.attention = None + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: the normalization layer named "norm3" """ + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.attention is not None: + out = self.attention(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def get_expansion(block, expansion=None): + """Get the expansion of a residual block. + + The block expansion will be obtained by the following order: + + 1. If ``expansion`` is given, just return it. + 2. If ``block`` has the attribute ``expansion``, then return + ``block.expansion``. + 3. Return the default value according the the block type: + 4 for ``ViPNAS_Bottleneck``. + + Args: + block (class): The block class. + expansion (int | None): The given expansion ratio. + + Returns: + int: The expansion of the block. + """ + if isinstance(expansion, int): + assert expansion > 0 + elif expansion is None: + if hasattr(block, 'expansion'): + expansion = block.expansion + elif issubclass(block, ViPNAS_Bottleneck): + expansion = 1 + else: + raise TypeError(f'expansion is not specified for {block.__name__}') + else: + raise TypeError('expansion must be an integer or None') + + return expansion + + +class ViPNAS_ResLayer(Sequential): + """ViPNAS_ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): Residual block used to build ViPNAS ResLayer. + num_blocks (int): Number of blocks. + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int, optional): The expansion for BasicBlock/Bottleneck. + If not specified, it will firstly be obtained via + ``block.expansion``. If the block has no attribute "expansion", + the following default values will be used: 1 for BasicBlock and + 4 for Bottleneck. Default: None. + stride (int): stride of the first block. Default: 1. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. Default: True + kernel_size (int): Kernel Size of the corresponding convolution layer + searched in the block. + groups (int): Group number of the corresponding convolution layer + searched in the block. + attention (bool): Whether to use attention module in the end of the + block. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + block, + num_blocks, + in_channels, + out_channels, + expansion=None, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + downsample_first=True, + kernel_size=3, + groups=1, + attention=False, + init_cfg=None, + **kwargs): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + self.block = block + self.expansion = get_expansion(block, expansion) + + downsample = None + if stride != 1 or in_channels != out_channels: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if downsample_first: + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=kernel_size, + groups=groups, + attention=attention, + **kwargs)) + in_channels = out_channels + for _ in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=kernel_size, + groups=groups, + attention=attention, + **kwargs)) + else: # downsample_first=False is for HourglassModule + for i in range(0, num_blocks - 1): + layers.append( + block( + in_channels=in_channels, + out_channels=in_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=kernel_size, + groups=groups, + attention=attention, + **kwargs)) + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=kernel_size, + groups=groups, + attention=attention, + **kwargs)) + + super().__init__(*layers, init_cfg=init_cfg) + + +@MODELS.register_module() +class ViPNAS_ResNet(BaseBackbone): + """ViPNAS_ResNet backbone. + + "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search" + More details can be found in the `paper + `__ . + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + wid (list(int)): Searched width config for each stage. + expan (list(int)): Searched expansion ratio config for each stage. + dep (list(int)): Searched depth config for each stage. + ks (list(int)): Searched kernel size config for each stage. + group (list(int)): Searched group number config for each stage. + att (list(bool)): Searched attention config for each stage. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: + ``[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]`` + """ + + arch_settings = { + 50: ViPNAS_Bottleneck, + } + + def __init__(self, + depth, + in_channels=3, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + wid=[48, 80, 160, 304, 608], + expan=[None, 1, 1, 1, 1], + dep=[None, 4, 6, 7, 3], + ks=[7, 3, 5, 5, 5], + group=[None, 16, 16, 16, 16], + att=[None, True, False, True, True], + init_cfg=[ + dict(type='Normal', std=0.001, layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + super().__init__(init_cfg=init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + self.stem_channels = dep[0] + self.num_stages = num_stages + assert 1 <= num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.block = self.arch_settings[depth] + self.stage_blocks = dep[1:1 + num_stages] + + self._make_stem_layer(in_channels, wid[0], ks[0]) + + self.res_layers = [] + _in_channels = wid[0] + for i, num_blocks in enumerate(self.stage_blocks): + expansion = get_expansion(self.block, expan[i + 1]) + _out_channels = wid[i + 1] * expansion + stride = strides[i] + dilation = dilations[i] + res_layer = self.make_res_layer( + block=self.block, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=_out_channels, + expansion=expansion, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + kernel_size=ks[i + 1], + groups=group[i + 1], + attention=att[i + 1]) + _in_channels = _out_channels + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = res_layer[-1].out_channels + + def make_res_layer(self, **kwargs): + """Make a ViPNAS ResLayer.""" + return ViPNAS_ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels, kernel_size): + """Make stem layer.""" + if self.deep_stem: + self.stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/builder.py b/mmpose/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..cefaedc29100bcbc4c5b9cde55db8f66b25ab637 --- /dev/null +++ b/mmpose/models/builder.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmpose.registry import MODELS + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +POSE_ESTIMATORS = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_pose_estimator(cfg): + """Build pose estimator.""" + return POSE_ESTIMATORS.build(cfg) + + +def build_posenet(cfg): + """Build posenet.""" + warnings.warn( + '``build_posenet`` will be deprecated soon, ' + 'please use ``build_pose_estimator`` instead.', DeprecationWarning) + return build_pose_estimator(cfg) diff --git a/mmpose/models/data_preprocessors/__init__.py b/mmpose/models/data_preprocessors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9bd22e2b20be84a17d05ab3058efd8d934f261 --- /dev/null +++ b/mmpose/models/data_preprocessors/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .data_preprocessor import PoseDataPreprocessor + +__all__ = ['PoseDataPreprocessor'] diff --git a/mmpose/models/data_preprocessors/__pycache__/__init__.cpython-38.pyc b/mmpose/models/data_preprocessors/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a2f2ab54cd6fe7d1b1c2e81d614ae7169be517b Binary files /dev/null and b/mmpose/models/data_preprocessors/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/data_preprocessors/__pycache__/data_preprocessor.cpython-38.pyc b/mmpose/models/data_preprocessors/__pycache__/data_preprocessor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e73f1ee78878c0e38595df4abf05639bbb30a06 Binary files /dev/null and b/mmpose/models/data_preprocessors/__pycache__/data_preprocessor.cpython-38.pyc differ diff --git a/mmpose/models/data_preprocessors/data_preprocessor.py b/mmpose/models/data_preprocessors/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfe54ab59108cbd033b9186b8b6aae0144677d9 --- /dev/null +++ b/mmpose/models/data_preprocessors/data_preprocessor.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.model import ImgDataPreprocessor + +from mmpose.registry import MODELS + + +@MODELS.register_module() +class PoseDataPreprocessor(ImgDataPreprocessor): + """Image pre-processor for pose estimation tasks.""" diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4d988a5f85bad9bb60728b4a020ba75970a8a8 --- /dev/null +++ b/mmpose/models/heads/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_head import BaseHead +from .coord_cls_heads import RTMCCHead, SimCCHead +from .heatmap_heads import (AssociativeEmbeddingHead, CIDHead, CPMHead, + HeatmapHead, MSPNHead, ViPNASHead) +from .hybrid_heads import DEKRHead +from .regression_heads import (DSNTHead, IntegralRegressionHead, + RegressionHead, RLEHead) + +__all__ = [ + 'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead', + 'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead', + 'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'CIDHead', 'RTMCCHead' +] diff --git a/mmpose/models/heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5df3e265f8c637082b69aeb8bf99bb78df0080d Binary files /dev/null and b/mmpose/models/heads/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/heads/__pycache__/base_head.cpython-38.pyc b/mmpose/models/heads/__pycache__/base_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e25d444bc1607e10e32c56ce604e9303eb621bb Binary files /dev/null and b/mmpose/models/heads/__pycache__/base_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/base_head.py b/mmpose/models/heads/base_head.py new file mode 100644 index 0000000000000000000000000000000000000000..14882db2439d1ff334f066939f5b0cb082b4d0ea --- /dev/null +++ b/mmpose/models/heads/base_head.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Tuple, Union + +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (Features, InstanceList, OptConfigType, + OptSampleList, Predictions) + + +class BaseHead(BaseModule, metaclass=ABCMeta): + """Base head. A subclass should override :meth:`predict` and :meth:`loss`. + + Args: + init_cfg (dict, optional): The extra init config of layers. + Defaults to None. + """ + + @abstractmethod + def forward(self, feats: Tuple[Tensor]): + """Forward the network.""" + + @abstractmethod + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: OptConfigType = {}) -> Predictions: + """Predict results from features.""" + + @abstractmethod + def loss(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + def decode(self, batch_outputs: Union[Tensor, + Tuple[Tensor]]) -> InstanceList: + """Decode keypoints from outputs. + + Args: + batch_outputs (Tensor | Tuple[Tensor]): The network outputs of + a data batch + + Returns: + List[InstanceData]: A list of InstanceData, each contains the + decoded pose information of the instances of one data sample. + """ + + def _pack_and_call(args, func): + if not isinstance(args, tuple): + args = (args, ) + return func(*args) + + if self.decoder is None: + raise RuntimeError( + f'The decoder has not been set in {self.__class__.__name__}. ' + 'Please set the decoder configs in the init parameters to ' + 'enable head methods `head.predict()` and `head.decode()`') + + if self.decoder.support_batch_decoding: + batch_keypoints, batch_scores = _pack_and_call( + batch_outputs, self.decoder.batch_decode) + + else: + batch_output_np = to_numpy(batch_outputs, unzip=True) + batch_keypoints = [] + batch_scores = [] + for outputs in batch_output_np: + keypoints, scores = _pack_and_call(outputs, + self.decoder.decode) + batch_keypoints.append(keypoints) + batch_scores.append(scores) + + preds = [ + InstanceData(keypoints=keypoints, keypoint_scores=scores) + for keypoints, scores in zip(batch_keypoints, batch_scores) + ] + + return preds diff --git a/mmpose/models/heads/coord_cls_heads/__init__.py b/mmpose/models/heads/coord_cls_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..104ff9130898956af54de6d243fcd4a16167f38b --- /dev/null +++ b/mmpose/models/heads/coord_cls_heads/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .rtmcc_head import RTMCCHead +from .simcc_head import SimCCHead + +__all__ = ['SimCCHead', 'RTMCCHead'] diff --git a/mmpose/models/heads/coord_cls_heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/coord_cls_heads/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5541e59c8e006ab9220cb2ab1ecd200490d94bd Binary files /dev/null and b/mmpose/models/heads/coord_cls_heads/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/heads/coord_cls_heads/__pycache__/rtmcc_head.cpython-38.pyc b/mmpose/models/heads/coord_cls_heads/__pycache__/rtmcc_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fff41e361acf3e2b09e6336bb4904b581f535a6 Binary files /dev/null and b/mmpose/models/heads/coord_cls_heads/__pycache__/rtmcc_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/coord_cls_heads/__pycache__/simcc_head.cpython-38.pyc b/mmpose/models/heads/coord_cls_heads/__pycache__/simcc_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffa9bd2c9a58693c71a67b51d37e27a09f428f00 Binary files /dev/null and b/mmpose/models/heads/coord_cls_heads/__pycache__/simcc_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/coord_cls_heads/rtmcc_head.py b/mmpose/models/heads/coord_cls_heads/rtmcc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..94d613192c089e2beca14afb832ee4370f8706cf --- /dev/null +++ b/mmpose/models/heads/coord_cls_heads/rtmcc_head.py @@ -0,0 +1,303 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Optional, Sequence, Tuple, Union + +import torch +from mmengine.dist import get_dist_info +from mmengine.structures import PixelData +from torch import Tensor, nn + +from mmpose.codecs.utils import get_simcc_normalized +from mmpose.evaluation.functional import simcc_pck_accuracy +from mmpose.models.utils.rtmcc_block import RTMCCBlock, ScaleNorm +from mmpose.models.utils.tta import flip_vectors +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, + OptSampleList) +from ..base_head import BaseHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class RTMCCHead(BaseHead): + """Top-down head introduced in RTMPose (2023). The head is composed of a + large-kernel convolutional layer, a fully-connected layer and a Gated + Attention Unit to generate 1d representation from low-resolution feature + maps. + + Args: + in_channels (int | sequence[int]): Number of channels in the input + feature map. + out_channels (int): Number of channels in the output heatmap. + input_size (tuple): Size of input image in shape [w, h]. + in_featuremap_size (int | sequence[int]): Size of input feature map. + simcc_split_ratio (float): Split ratio of pixels. + Default: 2.0. + final_layer_kernel_size (int): Kernel size of the convolutional layer. + Default: 1. + gau_cfg (Config): Config dict for the Gated Attention Unit. + Default: dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='ReLU', + use_rel_bias=False, + pos_enc=False). + loss (Config): Config of the keypoint loss. Defaults to use + :class:`KLDiscretLoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + """ + + def __init__( + self, + in_channels: Union[int, Sequence[int]], + out_channels: int, + input_size: Tuple[int, int], + in_featuremap_size: Tuple[int, int], + simcc_split_ratio: float = 2.0, + final_layer_kernel_size: int = 1, + gau_cfg: ConfigType = dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='ReLU', + use_rel_bias=False, + pos_enc=False), + loss: ConfigType = dict(type='KLDiscretLoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None, + ): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + self.input_size = input_size + self.in_featuremap_size = in_featuremap_size + self.simcc_split_ratio = simcc_split_ratio + + self.loss_module = MODELS.build(loss) + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + if isinstance(in_channels, (tuple, list)): + raise ValueError( + f'{self.__class__.__name__} does not support selecting ' + 'multiple input features.') + + # Define SimCC layers + flatten_dims = self.in_featuremap_size[0] * self.in_featuremap_size[1] + + self.final_layer = nn.Conv2d( + in_channels, + out_channels, + kernel_size=final_layer_kernel_size, + stride=1, + padding=final_layer_kernel_size // 2) + self.mlp = nn.Sequential( + ScaleNorm(flatten_dims), + nn.Linear(flatten_dims, gau_cfg['hidden_dims'], bias=False)) + + W = int(self.input_size[0] * self.simcc_split_ratio) + H = int(self.input_size[1] * self.simcc_split_ratio) + + self.gau = RTMCCBlock( + self.out_channels, + gau_cfg['hidden_dims'], + gau_cfg['hidden_dims'], + s=gau_cfg['s'], + expansion_factor=gau_cfg['expansion_factor'], + dropout_rate=gau_cfg['dropout_rate'], + drop_path=gau_cfg['drop_path'], + attn_type='self-attn', + act_fn=gau_cfg['act_fn'], + use_rel_bias=gau_cfg['use_rel_bias'], + pos_enc=gau_cfg['pos_enc']) + + self.cls_x = nn.Linear(gau_cfg['hidden_dims'], W, bias=False) + self.cls_y = nn.Linear(gau_cfg['hidden_dims'], H, bias=False) + + def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]: + """Forward the network. + + The input is multi scale feature maps and the + output is the heatmap. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + pred_x (Tensor): 1d representation of x. + pred_y (Tensor): 1d representation of y. + """ + feats = feats[-1] + + feats = self.final_layer(feats) # -> B, K, H, W + + # flatten the output heatmap + feats = torch.flatten(feats, 2) + + feats = self.mlp(feats) # -> B, K, hidden + + feats = self.gau(feats) + + pred_x = self.cls_x(feats) + pred_y = self.cls_y(feats) + + return pred_x, pred_y + + def predict( + self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + test_cfg: OptConfigType = {}, + ) -> InstanceList: + """Predict results from features. + + Args: + feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage + features (or multiple multi-stage features in TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. Defaults + to {} + + Returns: + List[InstanceData]: The pose predictions, each contains + the following fields: + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + - keypoint_x_labels (np.ndarray, optional): The predicted 1-D + intensity distribution in the x direction + - keypoint_y_labels (np.ndarray, optional): The predicted 1-D + intensity distribution in the y direction + """ + + if test_cfg.get('flip_test', False): + # TTA: flip test -> feats = [orig, flipped] + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + _feats, _feats_flip = feats + + _batch_pred_x, _batch_pred_y = self.forward(_feats) + + _batch_pred_x_flip, _batch_pred_y_flip = self.forward(_feats_flip) + _batch_pred_x_flip, _batch_pred_y_flip = flip_vectors( + _batch_pred_x_flip, + _batch_pred_y_flip, + flip_indices=flip_indices) + + batch_pred_x = (_batch_pred_x + _batch_pred_x_flip) * 0.5 + batch_pred_y = (_batch_pred_y + _batch_pred_y_flip) * 0.5 + else: + batch_pred_x, batch_pred_y = self.forward(feats) + + preds = self.decode((batch_pred_x, batch_pred_y)) + + if test_cfg.get('output_heatmaps', False): + rank, _ = get_dist_info() + if rank == 0: + warnings.warn('The predicted simcc values are normalized for ' + 'visualization. This may cause discrepancy ' + 'between the keypoint scores and the 1D heatmaps' + '.') + + # normalize the predicted 1d distribution + batch_pred_x = get_simcc_normalized(batch_pred_x) + batch_pred_y = get_simcc_normalized(batch_pred_y) + + B, K, _ = batch_pred_x.shape + # B, K, Wx -> B, K, Wx, 1 + x = batch_pred_x.reshape(B, K, 1, -1) + # B, K, Wy -> B, K, 1, Wy + y = batch_pred_y.reshape(B, K, -1, 1) + # B, K, Wx, Wy + batch_heatmaps = torch.matmul(y, x) + pred_fields = [ + PixelData(heatmaps=hm) for hm in batch_heatmaps.detach() + ] + + for pred_instances, pred_x, pred_y in zip(preds, + to_numpy(batch_pred_x), + to_numpy(batch_pred_y)): + + pred_instances.keypoint_x_labels = pred_x[None] + pred_instances.keypoint_y_labels = pred_y[None] + + return preds, pred_fields + else: + return preds + + def loss( + self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}, + ) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_x, pred_y = self.forward(feats) + + gt_x = torch.cat([ + d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples + ], + dim=0) + gt_y = torch.cat([ + d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples + ], + dim=0) + keypoint_weights = torch.cat( + [ + d.gt_instance_labels.keypoint_weights + for d in batch_data_samples + ], + dim=0, + ) + + pred_simcc = (pred_x, pred_y) + gt_simcc = (gt_x, gt_y) + + # calculate losses + losses = dict() + loss = self.loss_module(pred_simcc, gt_simcc, keypoint_weights) + + losses.update(loss_kpt=loss) + + # calculate accuracy + _, avg_acc, _ = simcc_pck_accuracy( + output=to_numpy(pred_simcc), + target=to_numpy(gt_simcc), + simcc_split_ratio=self.simcc_split_ratio, + mask=to_numpy(keypoint_weights) > 0, + ) + + acc_pose = torch.tensor(avg_acc, device=gt_x.device) + losses.update(acc_pose=acc_pose) + + return losses + + @property + def default_init_cfg(self): + init_cfg = [ + dict(type='Normal', layer=['Conv2d'], std=0.001), + dict(type='Constant', layer='BatchNorm2d', val=1), + dict(type='Normal', layer=['Linear'], std=0.01, bias=0), + ] + return init_cfg diff --git a/mmpose/models/heads/coord_cls_heads/simcc_head.py b/mmpose/models/heads/coord_cls_heads/simcc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b9287b7204ac1ca541be62aaf9e43a7d4f472210 --- /dev/null +++ b/mmpose/models/heads/coord_cls_heads/simcc_head.py @@ -0,0 +1,369 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Optional, Sequence, Tuple, Union + +import torch +from mmcv.cnn import build_conv_layer +from mmengine.dist import get_dist_info +from mmengine.structures import PixelData +from torch import Tensor, nn + +from mmpose.codecs.utils import get_simcc_normalized +from mmpose.evaluation.functional import simcc_pck_accuracy +from mmpose.models.utils.tta import flip_vectors +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, + OptSampleList) +from ..base_head import BaseHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class SimCCHead(BaseHead): + """Top-down heatmap head introduced in `SimCC`_ by Li et al (2022). The + head is composed of a few deconvolutional layers followed by a fully- + connected layer to generate 1d representation from low-resolution feature + maps. + + Args: + in_channels (int | sequence[int]): Number of channels in the input + feature map + out_channels (int): Number of channels in the output heatmap + input_size (tuple): Input image size in shape [w, h] + in_featuremap_size (int | sequence[int]): Size of input feature map + simcc_split_ratio (float): Split ratio of pixels + deconv_type (str, optional): The type of deconv head which should + be one of the following options: + + - ``'heatmap'``: make deconv layers in `HeatmapHead` + - ``'vipnas'``: make deconv layers in `ViPNASHead` + + Defaults to ``'Heatmap'`` + deconv_out_channels (sequence[int]): The output channel number of each + deconv layer. Defaults to ``(256, 256, 256)`` + deconv_kernel_sizes (sequence[int | tuple], optional): The kernel size + of each deconv layer. Each element should be either an integer for + both height and width dimensions, or a tuple of two integers for + the height and the width dimension respectively.Defaults to + ``(4, 4, 4)`` + deconv_num_groups (Sequence[int], optional): The group number of each + deconv layer. Defaults to ``(16, 16, 16)`` + conv_out_channels (sequence[int], optional): The output channel number + of each intermediate conv layer. ``None`` means no intermediate + conv layer between deconv layers and the final conv layer. + Defaults to ``None`` + conv_kernel_sizes (sequence[int | tuple], optional): The kernel size + of each intermediate conv layer. Defaults to ``None`` + final_layer (dict): Arguments of the final Conv2d layer. + Defaults to ``dict(kernel_size=1)`` + loss (Config): Config of the keypoint loss. Defaults to use + :class:`KLDiscretLoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`SimCC`: https://arxiv.org/abs/2107.03332 + """ + + _version = 2 + + def __init__( + self, + in_channels: Union[int, Sequence[int]], + out_channels: int, + input_size: Tuple[int, int], + in_featuremap_size: Tuple[int, int], + simcc_split_ratio: float = 2.0, + deconv_type: str = 'heatmap', + deconv_out_channels: OptIntSeq = (256, 256, 256), + deconv_kernel_sizes: OptIntSeq = (4, 4, 4), + deconv_num_groups: OptIntSeq = (16, 16, 16), + conv_out_channels: OptIntSeq = None, + conv_kernel_sizes: OptIntSeq = None, + final_layer: dict = dict(kernel_size=1), + loss: ConfigType = dict(type='KLDiscretLoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None, + ): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + if deconv_type not in {'heatmap', 'vipnas'}: + raise ValueError( + f'{self.__class__.__name__} got invalid `deconv_type` value' + f'{deconv_type}. Should be one of ' + '{"heatmap", "vipnas"}') + + self.in_channels = in_channels + self.out_channels = out_channels + self.input_size = input_size + self.in_featuremap_size = in_featuremap_size + self.simcc_split_ratio = simcc_split_ratio + self.loss_module = MODELS.build(loss) + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + num_deconv = len(deconv_out_channels) if deconv_out_channels else 0 + if num_deconv != 0: + self.heatmap_size = tuple( + [s * (2**num_deconv) for s in in_featuremap_size]) + + # deconv layers + 1x1 conv + self.deconv_head = self._make_deconv_head( + in_channels=in_channels, + out_channels=out_channels, + deconv_type=deconv_type, + deconv_out_channels=deconv_out_channels, + deconv_kernel_sizes=deconv_kernel_sizes, + deconv_num_groups=deconv_num_groups, + conv_out_channels=conv_out_channels, + conv_kernel_sizes=conv_kernel_sizes, + final_layer=final_layer) + + if final_layer is not None: + in_channels = out_channels + else: + in_channels = deconv_out_channels[-1] + + else: + self.deconv_head = None + + if final_layer is not None: + cfg = dict( + type='Conv2d', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1) + cfg.update(final_layer) + self.final_layer = build_conv_layer(cfg) + else: + self.final_layer = None + + self.heatmap_size = in_featuremap_size + + # Define SimCC layers + flatten_dims = self.heatmap_size[0] * self.heatmap_size[1] + + W = int(self.input_size[0] * self.simcc_split_ratio) + H = int(self.input_size[1] * self.simcc_split_ratio) + + self.mlp_head_x = nn.Linear(flatten_dims, W) + self.mlp_head_y = nn.Linear(flatten_dims, H) + + def _make_deconv_head( + self, + in_channels: Union[int, Sequence[int]], + out_channels: int, + deconv_type: str = 'heatmap', + deconv_out_channels: OptIntSeq = (256, 256, 256), + deconv_kernel_sizes: OptIntSeq = (4, 4, 4), + deconv_num_groups: OptIntSeq = (16, 16, 16), + conv_out_channels: OptIntSeq = None, + conv_kernel_sizes: OptIntSeq = None, + final_layer: dict = dict(kernel_size=1) + ) -> nn.Module: + """Create deconvolutional layers by given parameters.""" + + if deconv_type == 'heatmap': + deconv_head = MODELS.build( + dict( + type='HeatmapHead', + in_channels=self.in_channels, + out_channels=out_channels, + deconv_out_channels=deconv_out_channels, + deconv_kernel_sizes=deconv_kernel_sizes, + conv_out_channels=conv_out_channels, + conv_kernel_sizes=conv_kernel_sizes, + final_layer=final_layer)) + else: + deconv_head = MODELS.build( + dict( + type='ViPNASHead', + in_channels=in_channels, + out_channels=out_channels, + deconv_out_channels=deconv_out_channels, + deconv_num_groups=deconv_num_groups, + conv_out_channels=conv_out_channels, + conv_kernel_sizes=conv_kernel_sizes, + final_layer=final_layer)) + + return deconv_head + + def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]: + """Forward the network. The input is multi scale feature maps and the + output is the heatmap. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + pred_x (Tensor): 1d representation of x. + pred_y (Tensor): 1d representation of y. + """ + if self.deconv_head is None: + feats = feats[-1] + if self.final_layer is not None: + feats = self.final_layer(feats) + else: + feats = self.deconv_head(feats) + + # flatten the output heatmap + x = torch.flatten(feats, 2) + + pred_x = self.mlp_head_x(x) + pred_y = self.mlp_head_y(x) + + return pred_x, pred_y + + def predict( + self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + test_cfg: OptConfigType = {}, + ) -> InstanceList: + """Predict results from features. + + Args: + feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage + features (or multiple multi-stage features in TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. Defaults + to {} + + Returns: + List[InstanceData]: The pose predictions, each contains + the following fields: + + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + - keypoint_x_labels (np.ndarray, optional): The predicted 1-D + intensity distribution in the x direction + - keypoint_y_labels (np.ndarray, optional): The predicted 1-D + intensity distribution in the y direction + """ + + if test_cfg.get('flip_test', False): + # TTA: flip test -> feats = [orig, flipped] + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + _feats, _feats_flip = feats + + _batch_pred_x, _batch_pred_y = self.forward(_feats) + + _batch_pred_x_flip, _batch_pred_y_flip = self.forward(_feats_flip) + _batch_pred_x_flip, _batch_pred_y_flip = flip_vectors( + _batch_pred_x_flip, + _batch_pred_y_flip, + flip_indices=flip_indices) + + batch_pred_x = (_batch_pred_x + _batch_pred_x_flip) * 0.5 + batch_pred_y = (_batch_pred_y + _batch_pred_y_flip) * 0.5 + else: + batch_pred_x, batch_pred_y = self.forward(feats) + + preds = self.decode((batch_pred_x, batch_pred_y)) + + if test_cfg.get('output_heatmaps', False): + rank, _ = get_dist_info() + if rank == 0: + warnings.warn('The predicted simcc values are normalized for ' + 'visualization. This may cause discrepancy ' + 'between the keypoint scores and the 1D heatmaps' + '.') + + # normalize the predicted 1d distribution + sigma = self.decoder.sigma + batch_pred_x = get_simcc_normalized(batch_pred_x, sigma[0]) + batch_pred_y = get_simcc_normalized(batch_pred_y, sigma[1]) + + B, K, _ = batch_pred_x.shape + # B, K, Wx -> B, K, Wx, 1 + x = batch_pred_x.reshape(B, K, 1, -1) + # B, K, Wy -> B, K, 1, Wy + y = batch_pred_y.reshape(B, K, -1, 1) + # B, K, Wx, Wy + batch_heatmaps = torch.matmul(y, x) + pred_fields = [ + PixelData(heatmaps=hm) for hm in batch_heatmaps.detach() + ] + + for pred_instances, pred_x, pred_y in zip(preds, + to_numpy(batch_pred_x), + to_numpy(batch_pred_y)): + + pred_instances.keypoint_x_labels = pred_x[None] + pred_instances.keypoint_y_labels = pred_y[None] + + return preds, pred_fields + else: + return preds + + def loss( + self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}, + ) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_x, pred_y = self.forward(feats) + + gt_x = torch.cat([ + d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples + ], + dim=0) + gt_y = torch.cat([ + d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples + ], + dim=0) + keypoint_weights = torch.cat( + [ + d.gt_instance_labels.keypoint_weights + for d in batch_data_samples + ], + dim=0, + ) + + pred_simcc = (pred_x, pred_y) + gt_simcc = (gt_x, gt_y) + + # calculate losses + losses = dict() + loss = self.loss_module(pred_simcc, gt_simcc, keypoint_weights) + + losses.update(loss_kpt=loss) + + # calculate accuracy + _, avg_acc, _ = simcc_pck_accuracy( + output=to_numpy(pred_simcc), + target=to_numpy(gt_simcc), + simcc_split_ratio=self.simcc_split_ratio, + mask=to_numpy(keypoint_weights) > 0, + ) + + acc_pose = torch.tensor(avg_acc, device=gt_x.device) + losses.update(acc_pose=acc_pose) + + return losses + + @property + def default_init_cfg(self): + init_cfg = [ + dict( + type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001), + dict(type='Constant', layer='BatchNorm2d', val=1), + dict(type='Normal', layer=['Linear'], std=0.01, bias=0), + ] + return init_cfg diff --git a/mmpose/models/heads/heatmap_heads/__init__.py b/mmpose/models/heads/heatmap_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b482216b36f61ceb66aae8974ae178a8455d5022 --- /dev/null +++ b/mmpose/models/heads/heatmap_heads/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ae_head import AssociativeEmbeddingHead +from .cid_head import CIDHead +from .cpm_head import CPMHead +from .heatmap_head import HeatmapHead +from .mspn_head import MSPNHead +from .vipnas_head import ViPNASHead + +__all__ = [ + 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead', + 'AssociativeEmbeddingHead', 'CIDHead' +] diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06f989363ba4a109f08fe2c797f0aa2faa817ca5 Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/ae_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/ae_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..113225af5bbeff20f60677c1c806dc498c3adf1c Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/ae_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/cid_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/cid_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6674128484615944762f461f506bdc2911336969 Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/cid_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/cpm_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/cpm_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44883ef860dda1bd001a10eb60bd4588fd5dc56c Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/cpm_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/heatmap_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/heatmap_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a0b066f379c0fe1630dae375181429f7af532c8 Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/heatmap_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/mspn_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/mspn_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..232487516041aac335ad248d1004010ebeac6b1e Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/mspn_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/vipnas_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/vipnas_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58ee376a83c13788611754e5e93e8d955f39b442 Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/vipnas_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/heatmap_heads/ae_head.py b/mmpose/models/heads/heatmap_heads/ae_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bd12d57a333bbab91dfdcb11afddd0cbd351b945 --- /dev/null +++ b/mmpose/models/heads/heatmap_heads/ae_head.py @@ -0,0 +1,291 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple, Union + +import torch +from mmengine.structures import PixelData +from mmengine.utils import is_list_of +from torch import Tensor + +from mmpose.models.utils.tta import aggregate_heatmaps, flip_heatmaps +from mmpose.registry import MODELS +from mmpose.utils.typing import (ConfigType, Features, OptConfigType, + OptSampleList, Predictions) +from .heatmap_head import HeatmapHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class AssociativeEmbeddingHead(HeatmapHead): + + def __init__(self, + in_channels: Union[int, Sequence[int]], + num_keypoints: int, + tag_dim: int = 1, + tag_per_keypoint: bool = True, + deconv_out_channels: OptIntSeq = (256, 256, 256), + deconv_kernel_sizes: OptIntSeq = (4, 4, 4), + conv_out_channels: OptIntSeq = None, + conv_kernel_sizes: OptIntSeq = None, + final_layer: dict = dict(kernel_size=1), + keypoint_loss: ConfigType = dict(type='KeypointMSELoss'), + tag_loss: ConfigType = dict(type='AssociativeEmbeddingLoss'), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if tag_per_keypoint: + out_channels = num_keypoints * (1 + tag_dim) + else: + out_channels = num_keypoints + tag_dim + + loss = dict( + type='CombinedLoss', + losses=dict(keypoint_loss=keypoint_loss, tag_loss=tag_loss)) + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + deconv_out_channels=deconv_out_channels, + deconv_kernel_sizes=deconv_kernel_sizes, + conv_out_channels=conv_out_channels, + conv_kernel_sizes=conv_kernel_sizes, + final_layer=final_layer, + loss=loss, + decoder=decoder, + init_cfg=init_cfg) + + self.num_keypoints = num_keypoints + self.tag_dim = tag_dim + self.tag_per_keypoint = tag_per_keypoint + + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from features. + + Args: + feats (Features): The features which could be in following forms: + + - Tuple[Tensor]: multi-stage features from the backbone + - List[Tuple[Tensor]]: multiple features for TTA where either + `flip_test` or `multiscale_test` is applied + - List[List[Tuple[Tensor]]]: multiple features for TTA where + both `flip_test` and `multiscale_test` are applied + + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. Defaults + to {} + + Returns: + Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If + ``test_cfg['output_heatmap']==True``, return both pose and heatmap + prediction; otherwise only return the pose prediction. + + The pose prediction is a list of ``InstanceData``, each contains + the following fields: + + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + + The heatmap prediction is a list of ``PixelData``, each contains + the following fields: + + - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) + """ + # test configs + multiscale_test = test_cfg.get('multiscale_test', False) + flip_test = test_cfg.get('flip_test', False) + shift_heatmap = test_cfg.get('shift_heatmap', False) + align_corners = test_cfg.get('align_corners', False) + restore_heatmap_size = test_cfg.get('restore_heatmap_size', False) + output_heatmaps = test_cfg.get('output_heatmaps', False) + + # enable multi-scale test + if multiscale_test: + # TTA: multi-scale test + assert is_list_of(feats, list if flip_test else tuple) + else: + assert is_list_of(feats, tuple if flip_test else Tensor) + feats = [feats] + + # resize heatmaps to align with with input size + if restore_heatmap_size: + img_shape = batch_data_samples[0].metainfo['img_shape'] + assert all(d.metainfo['img_shape'] == img_shape + for d in batch_data_samples) + img_h, img_w = img_shape + heatmap_size = (img_w, img_h) + else: + heatmap_size = None + + multiscale_heatmaps = [] + multiscale_tags = [] + + for scale_idx, _feats in enumerate(feats): + if not flip_test: + _heatmaps, _tags = self.forward(_feats) + + else: + # TTA: flip test + assert isinstance(_feats, list) and len(_feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + # original + _feats_orig, _feats_flip = _feats + _heatmaps_orig, _tags_orig = self.forward(_feats_orig) + + # flipped + _heatmaps_flip, _tags_flip = self.forward(_feats_flip) + _heatmaps_flip = flip_heatmaps( + _heatmaps_flip, + flip_mode='heatmap', + flip_indices=flip_indices, + shift_heatmap=shift_heatmap) + _tags_flip = self._flip_tags( + _tags_flip, + flip_indices=flip_indices, + shift_heatmap=shift_heatmap) + + # aggregated heatmaps + _heatmaps = aggregate_heatmaps( + [_heatmaps_orig, _heatmaps_flip], + size=heatmap_size, + align_corners=align_corners, + mode='average') + + # aggregated tags (only at original scale) + if scale_idx == 0: + _tags = aggregate_heatmaps([_tags_orig, _tags_flip], + size=heatmap_size, + align_corners=align_corners, + mode='concat') + else: + _tags = None + + multiscale_heatmaps.append(_heatmaps) + multiscale_tags.append(_tags) + + # aggregate multi-scale heatmaps + if len(feats) > 1: + batch_heatmaps = aggregate_heatmaps( + multiscale_heatmaps, + align_corners=align_corners, + mode='average') + else: + batch_heatmaps = multiscale_heatmaps[0] + # only keep tags at original scale + batch_tags = multiscale_tags[0] + + batch_outputs = tuple([batch_heatmaps, batch_tags]) + preds = self.decode(batch_outputs) + + if output_heatmaps: + pred_fields = [] + for _heatmaps, _tags in zip(batch_heatmaps.detach(), + batch_tags.detach()): + pred_fields.append(PixelData(heatmaps=_heatmaps, tags=_tags)) + + return preds, pred_fields + else: + return preds + + def _flip_tags(self, + tags: Tensor, + flip_indices: List[int], + shift_heatmap: bool = True): + """Flip the tagging heatmaps horizontally for test-time augmentation. + + Args: + tags (Tensor): batched tagging heatmaps to flip + flip_indices (List[int]): The indices of each keypoint's symmetric + keypoint + shift_heatmap (bool): Shift the flipped heatmaps to align with the + original heatmaps and improve accuracy. Defaults to ``True`` + + Returns: + Tensor: flipped tagging heatmaps + """ + B, C, H, W = tags.shape + K = self.num_keypoints + L = self.tag_dim + + tags = tags.flip(-1) + + if self.tag_per_keypoint: + assert C == K * L + tags = tags.view(B, L, K, H, W) + tags = tags[:, :, flip_indices] + tags = tags.view(B, C, H, W) + + if shift_heatmap: + tags[..., 1:] = tags[..., :-1].clone() + + return tags + + def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]: + """Forward the network. The input is multi scale feature maps and the + output is the heatmaps and tags. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + tuple: + - heatmaps (Tensor): output heatmaps + - tags (Tensor): output tags + """ + + output = super().forward(feats) + heatmaps = output[:, :self.num_keypoints] + tags = output[:, self.num_keypoints:] + return heatmaps, tags + + def loss(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + feats (Tuple[Tensor]): The multi-stage features + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + train_cfg (dict): The runtime config for training process. + Defaults to {} + + Returns: + dict: A dictionary of losses. + """ + pred_heatmaps, pred_tags = self.forward(feats) + + if not self.tag_per_keypoint: + pred_tags = pred_tags.repeat((1, self.num_keypoints, 1, 1)) + + gt_heatmaps = torch.stack( + [d.gt_fields.heatmaps for d in batch_data_samples]) + gt_masks = torch.stack( + [d.gt_fields.heatmap_mask for d in batch_data_samples]) + keypoint_weights = torch.cat([ + d.gt_instance_labels.keypoint_weights for d in batch_data_samples + ]) + keypoint_indices = [ + d.gt_instance_labels.keypoint_indices for d in batch_data_samples + ] + + loss_kpt = self.loss_module.keypoint_loss(pred_heatmaps, gt_heatmaps, + keypoint_weights, gt_masks) + + loss_pull, loss_push = self.loss_module.tag_loss( + pred_tags, keypoint_indices) + + losses = { + 'loss_kpt': loss_kpt, + 'loss_pull': loss_pull, + 'loss_push': loss_push + } + + return losses diff --git a/mmpose/models/heads/heatmap_heads/cid_head.py b/mmpose/models/heads/heatmap_heads/cid_head.py new file mode 100644 index 0000000000000000000000000000000000000000..39e0211a3e135c1c101c14e37956528d3330ca1b --- /dev/null +++ b/mmpose/models/heads/heatmap_heads/cid_head.py @@ -0,0 +1,743 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_conv_layer +from mmengine.model import BaseModule, ModuleDict +from mmengine.structures import InstanceData, PixelData +from torch import Tensor + +from mmpose.models.utils.tta import flip_heatmaps +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.typing import (ConfigType, Features, OptConfigType, + OptSampleList, Predictions) +from ..base_head import BaseHead + + +def smooth_heatmaps(heatmaps: Tensor, blur_kernel_size: int) -> Tensor: + """Smooth the heatmaps by blurring and averaging. + + Args: + heatmaps (Tensor): The heatmaps to smooth. + blur_kernel_size (int): The kernel size for blurring the heatmaps. + + Returns: + Tensor: The smoothed heatmaps. + """ + smoothed_heatmaps = torch.nn.functional.avg_pool2d( + heatmaps, blur_kernel_size, 1, (blur_kernel_size - 1) // 2) + smoothed_heatmaps = (heatmaps + smoothed_heatmaps) / 2.0 + return smoothed_heatmaps + + +class TruncSigmoid(nn.Sigmoid): + """A sigmoid activation function that truncates the output to the given + range. + + Args: + min (float, optional): The minimum value to clamp the output to. + Defaults to 0.0 + max (float, optional): The maximum value to clamp the output to. + Defaults to 1.0 + """ + + def __init__(self, min: float = 0.0, max: float = 1.0): + super(TruncSigmoid, self).__init__() + self.min = min + self.max = max + + def forward(self, input: Tensor) -> Tensor: + """Computes the truncated sigmoid activation of the input tensor.""" + output = torch.sigmoid(input) + output = output.clamp(min=self.min, max=self.max) + return output + + +class IIAModule(BaseModule): + """Instance Information Abstraction module introduced in `CID`. This module + extracts the feature representation vectors for each instance. + + Args: + in_channels (int): Number of channels in the input feature tensor + out_channels (int): Number of channels of the output heatmaps + clamp_delta (float, optional): A small value that prevents the sigmoid + activation from becoming saturated. Defaults to 1e-4. + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + clamp_delta: float = 1e-4, + init_cfg: OptConfigType = None, + ): + super().__init__(init_cfg=init_cfg) + + self.keypoint_root_conv = build_conv_layer( + dict( + type='Conv2d', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1)) + self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta) + + def forward(self, feats: Tensor): + heatmaps = self.keypoint_root_conv(feats) + heatmaps = self.sigmoid(heatmaps) + return heatmaps + + def _sample_feats(self, feats: Tensor, indices: Tensor) -> Tensor: + """Extract feature vectors at the specified indices from the input + feature map. + + Args: + feats (Tensor): Input feature map. + indices (Tensor): Indices of the feature vectors to extract. + + Returns: + Tensor: Extracted feature vectors. + """ + assert indices.dtype == torch.long + if indices.shape[1] == 3: + b, w, h = [ind.squeeze(-1) for ind in indices.split(1, -1)] + instance_feats = feats[b, :, h, w] + elif indices.shape[1] == 2: + w, h = [ind.squeeze(-1) for ind in indices.split(1, -1)] + instance_feats = feats[:, :, h, w] + instance_feats = instance_feats.permute(0, 2, 1) + instance_feats = instance_feats.reshape(-1, + instance_feats.shape[-1]) + + else: + raise ValueError(f'`indices` should have 2 or 3 channels, ' + f'but got f{indices.shape[1]}') + return instance_feats + + def _hierarchical_pool(self, heatmaps: Tensor) -> Tensor: + """Conduct max pooling on the input heatmaps with different kernel size + according to the input size. + + Args: + heatmaps (Tensor): Input heatmaps. + + Returns: + Tensor: Result of hierarchical pooling. + """ + map_size = (heatmaps.shape[-1] + heatmaps.shape[-2]) / 2.0 + if map_size > 300: + maxm = torch.nn.functional.max_pool2d(heatmaps, 7, 1, 3) + elif map_size > 200: + maxm = torch.nn.functional.max_pool2d(heatmaps, 5, 1, 2) + else: + maxm = torch.nn.functional.max_pool2d(heatmaps, 3, 1, 1) + return maxm + + def forward_train(self, feats: Tensor, instance_coords: Tensor, + instance_imgids: Tensor) -> Tuple[Tensor, Tensor]: + """Forward pass during training. + + Args: + feats (Tensor): Input feature tensor. + instance_coords (Tensor): Coordinates of the instance roots. + instance_imgids (Tensor): Sample indices of each instances + in the batch. + + Returns: + Tuple[Tensor, Tensor]: Extracted feature vectors and heatmaps + for the instances. + """ + heatmaps = self.forward(feats) + indices = torch.cat((instance_imgids[:, None], instance_coords), dim=1) + instance_feats = self._sample_feats(feats, indices) + + return instance_feats, heatmaps + + def forward_test( + self, feats: Tensor, test_cfg: Dict + ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + """Forward pass during testing. + + Args: + feats (Tensor): Input feature tensor. + test_cfg (Dict): Testing configuration, including: + - blur_kernel_size (int, optional): Kernel size for blurring + the heatmaps. Defaults to 3. + - max_instances (int, optional): Maximum number of instances + to extract. Defaults to 30. + - score_threshold (float, optional): Minimum score for + extracting an instance. Defaults to 0.01. + - flip_test (bool, optional): Whether to compute the average + of the heatmaps across the batch dimension. + Defaults to False. + + Returns: + A tuple of Tensor including extracted feature vectors, + coordinates, and scores of the instances. Any of these can be + empty Tensor if no instances are extracted. + """ + blur_kernel_size = test_cfg.get('blur_kernel_size', 3) + max_instances = test_cfg.get('max_instances', 30) + score_threshold = test_cfg.get('score_threshold', 0.01) + H, W = feats.shape[-2:] + + # compute heatmaps + heatmaps = self.forward(feats).narrow(1, -1, 1) + if test_cfg.get('flip_test', False): + heatmaps = heatmaps.mean(dim=0, keepdims=True) + smoothed_heatmaps = smooth_heatmaps(heatmaps, blur_kernel_size) + + # decode heatmaps + maximums = self._hierarchical_pool(smoothed_heatmaps) + maximums = torch.eq(maximums, smoothed_heatmaps).float() + maximums = (smoothed_heatmaps * maximums).reshape(-1) + scores, pos_ind = maximums.topk(max_instances, dim=0) + select_ind = (scores > (score_threshold)).nonzero().squeeze(1) + scores, pos_ind = scores[select_ind], pos_ind[select_ind] + + # sample feature vectors from feature map + instance_coords = torch.stack((pos_ind % W, pos_ind // W), dim=1) + instance_feats = self._sample_feats(feats, instance_coords) + + return instance_feats, instance_coords, scores + + +class ChannelAttention(nn.Module): + """Channel-wise attention module introduced in `CID`. + + Args: + in_channels (int): The number of channels of the input instance + vectors. + out_channels (int): The number of channels of the transformed instance + vectors. + """ + + def __init__(self, in_channels: int, out_channels: int): + super(ChannelAttention, self).__init__() + self.atn = nn.Linear(in_channels, out_channels) + + def forward(self, global_feats: Tensor, instance_feats: Tensor) -> Tensor: + """Applies attention to the channel dimension of the input tensor.""" + + instance_feats = self.atn(instance_feats).unsqueeze(2).unsqueeze(3) + return global_feats * instance_feats + + +class SpatialAttention(nn.Module): + """Spatial-wise attention module introduced in `CID`. + + Args: + in_channels (int): The number of channels of the input instance + vectors. + out_channels (int): The number of channels of the transformed instance + vectors. + """ + + def __init__(self, in_channels, out_channels): + super(SpatialAttention, self).__init__() + self.atn = nn.Linear(in_channels, out_channels) + self.feat_stride = 4 + self.conv = nn.Conv2d(3, 1, 5, 1, 2) + + def _get_pixel_coords(self, heatmap_size: Tuple, device: str = 'cpu'): + """Get pixel coordinates for each element in the heatmap. + + Args: + heatmap_size (tuple): Size of the heatmap in (W, H) format. + device (str): Device to put the resulting tensor on. + + Returns: + Tensor of shape (batch_size, num_pixels, 2) containing the pixel + coordinates for each element in the heatmap. + """ + w, h = heatmap_size + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) + pixel_coords = torch.stack((x, y), dim=-1).reshape(-1, 2) + pixel_coords = pixel_coords.float().to(device) + 0.5 + return pixel_coords + + def forward(self, global_feats: Tensor, instance_feats: Tensor, + instance_coords: Tensor) -> Tensor: + """Perform spatial attention. + + Args: + global_feats (Tensor): Tensor containing the global features. + instance_feats (Tensor): Tensor containing the instance feature + vectors. + instance_coords (Tensor): Tensor containing the root coordinates + of the instances. + + Returns: + Tensor containing the modulated global features. + """ + B, C, H, W = global_feats.size() + + instance_feats = self.atn(instance_feats).reshape(B, C, 1, 1) + feats = global_feats * instance_feats.expand_as(global_feats) + fsum = torch.sum(feats, dim=1, keepdim=True) + + pixel_coords = self._get_pixel_coords((W, H), feats.device) + relative_coords = instance_coords.reshape( + -1, 1, 2) - pixel_coords.reshape(1, -1, 2) + relative_coords = relative_coords.permute(0, 2, 1) / 32.0 + relative_coords = relative_coords.reshape(B, 2, H, W) + + input_feats = torch.cat((fsum, relative_coords), dim=1) + mask = self.conv(input_feats).sigmoid() + return global_feats * mask + + +class GFDModule(BaseModule): + """Global Feature Decoupling module introduced in `CID`. This module + extracts the decoupled heatmaps for each instance. + + Args: + in_channels (int): Number of channels in the input feature map + out_channels (int): Number of channels of the output heatmaps + for each instance + gfd_channels (int): Number of channels in the transformed feature map + clamp_delta (float, optional): A small value that prevents the sigmoid + activation from becoming saturated. Defaults to 1e-4. + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + gfd_channels: int, + clamp_delta: float = 1e-4, + init_cfg: OptConfigType = None, + ): + super().__init__(init_cfg=init_cfg) + + self.conv_down = build_conv_layer( + dict( + type='Conv2d', + in_channels=in_channels, + out_channels=gfd_channels, + kernel_size=1)) + + self.channel_attention = ChannelAttention(in_channels, gfd_channels) + self.spatial_attention = SpatialAttention(in_channels, gfd_channels) + self.fuse_attention = build_conv_layer( + dict( + type='Conv2d', + in_channels=gfd_channels * 2, + out_channels=gfd_channels, + kernel_size=1)) + self.heatmap_conv = build_conv_layer( + dict( + type='Conv2d', + in_channels=gfd_channels, + out_channels=out_channels, + kernel_size=1)) + self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta) + + def forward( + self, + feats: Tensor, + instance_feats: Tensor, + instance_coords: Tensor, + instance_imgids: Tensor, + ) -> Tensor: + """Extract decoupled heatmaps for each instance. + + Args: + feats (Tensor): Input feature maps. + instance_feats (Tensor): Tensor containing the instance feature + vectors. + instance_coords (Tensor): Tensor containing the root coordinates + of the instances. + instance_imgids (Tensor): Sample indices of each instances + in the batch. + + Returns: + A tensor containing decoupled heatmaps. + """ + + global_feats = self.conv_down(feats) + global_feats = global_feats[instance_imgids] + cond_instance_feats = torch.cat( + (self.channel_attention(global_feats, instance_feats), + self.spatial_attention(global_feats, instance_feats, + instance_coords)), + dim=1) + + cond_instance_feats = self.fuse_attention(cond_instance_feats) + cond_instance_feats = torch.nn.functional.relu(cond_instance_feats) + cond_instance_feats = self.heatmap_conv(cond_instance_feats) + heatmaps = self.sigmoid(cond_instance_feats) + + return heatmaps + + +@MODELS.register_module() +class CIDHead(BaseHead): + """Contextual Instance Decoupling head introduced in `Contextual Instance + Decoupling for Robust Multi-Person Pose Estimation (CID)`_ by Wang et al + (2022). The head is composed of an Instance Information Abstraction (IIA) + module and a Global Feature Decoupling (GFD) module. + + Args: + in_channels (int | Sequence[int]): Number of channels in the input + feature map + num_keypoints (int): Number of keypoints + gfd_channels (int): Number of filters in GFD module + max_train_instances (int): Maximum number of instances in a batch + during training. Defaults to 200 + heatmap_loss (Config): Config of the heatmap loss. Defaults to use + :class:`KeypointMSELoss` + coupled_heatmap_loss (Config): Config of the loss for coupled heatmaps. + Defaults to use :class:`SoftWeightSmoothL1Loss` + decoupled_heatmap_loss (Config): Config of the loss for decoupled + heatmaps. Defaults to use :class:`SoftWeightSmoothL1Loss` + contrastive_loss (Config): Config of the contrastive loss for + representation vectors of instances. Defaults to use + :class:`InfoNCELoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_ + Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_ + CVPR_2022_paper.html + """ + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + gfd_channels: int, + num_keypoints: int, + prior_prob: float = 0.01, + coupled_heatmap_loss: OptConfigType = dict( + type='FocalHeatmapLoss'), + decoupled_heatmap_loss: OptConfigType = dict( + type='FocalHeatmapLoss'), + contrastive_loss: OptConfigType = dict(type='InfoNCELoss'), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.num_keypoints = num_keypoints + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + # build sub-modules + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.iia_module = IIAModule( + in_channels, + num_keypoints + 1, + init_cfg=init_cfg + [ + dict( + type='Normal', + layer=['Conv2d', 'Linear'], + std=0.001, + override=dict( + name='keypoint_root_conv', + type='Normal', + std=0.001, + bias=bias_value)) + ]) + self.gfd_module = GFDModule( + in_channels, + num_keypoints, + gfd_channels, + init_cfg=init_cfg + [ + dict( + type='Normal', + layer=['Conv2d', 'Linear'], + std=0.001, + override=dict( + name='heatmap_conv', + type='Normal', + std=0.001, + bias=bias_value)) + ]) + + # build losses + self.loss_module = ModuleDict( + dict( + heatmap_coupled=MODELS.build(coupled_heatmap_loss), + heatmap_decoupled=MODELS.build(decoupled_heatmap_loss), + contrastive=MODELS.build(contrastive_loss), + )) + + # Register the hook to automatically convert old version state dicts + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + @property + def default_init_cfg(self): + init_cfg = [ + dict(type='Normal', layer=['Conv2d', 'Linear'], std=0.001), + dict(type='Constant', layer='BatchNorm2d', val=1) + ] + return init_cfg + + def forward(self, feats: Tuple[Tensor]) -> Tensor: + """Forward the network. The input is multi scale feature maps and the + output is the heatmap. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + Tensor: output heatmap. + """ + feats = feats[-1] + instance_info = self.iia_module.forward_test(feats, {}) + instance_feats, instance_coords, instance_scores = instance_info + instance_imgids = torch.zeros( + instance_coords.size(0), dtype=torch.long, device=feats.device) + instance_heatmaps = self.gfd_module(feats, instance_feats, + instance_coords, instance_imgids) + + return instance_heatmaps + + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from features. + + Args: + feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage + features (or multiple multi-stage features in TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. Defaults + to {} + + Returns: + Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If + ``test_cfg['output_heatmap']==True``, return both pose and heatmap + prediction; otherwise only return the pose prediction. + + The pose prediction is a list of ``InstanceData``, each contains + the following fields: + + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + + The heatmap prediction is a list of ``PixelData``, each contains + the following fields: + + - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) + """ + metainfo = batch_data_samples[0].metainfo + + if test_cfg.get('flip_test', False): + assert isinstance(feats, list) and len(feats) == 2 + + feats_flipped = flip_heatmaps(feats[1][-1], shift_heatmap=False) + feats = torch.cat((feats[0][-1], feats_flipped)) + else: + feats = feats[-1] + + instance_info = self.iia_module.forward_test(feats, test_cfg) + instance_feats, instance_coords, instance_scores = instance_info + if len(instance_coords) > 0: + instance_imgids = torch.zeros( + instance_coords.size(0), dtype=torch.long, device=feats.device) + if test_cfg.get('flip_test', False): + instance_coords = torch.cat((instance_coords, instance_coords)) + instance_imgids = torch.cat( + (instance_imgids, instance_imgids + 1)) + instance_heatmaps = self.gfd_module(feats, instance_feats, + instance_coords, + instance_imgids) + if test_cfg.get('flip_test', False): + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + instance_heatmaps, instance_heatmaps_flip = torch.chunk( + instance_heatmaps, 2, dim=0) + instance_heatmaps_flip = \ + instance_heatmaps_flip[:, flip_indices, :, :] + instance_heatmaps = (instance_heatmaps + + instance_heatmaps_flip) / 2.0 + instance_heatmaps = smooth_heatmaps( + instance_heatmaps, test_cfg.get('blur_kernel_size', 3)) + + preds = self.decode((instance_heatmaps, instance_scores[:, None])) + preds = InstanceData.cat(preds) + preds.keypoints[..., 0] += metainfo['input_size'][ + 0] / instance_heatmaps.shape[-1] / 2.0 + preds.keypoints[..., 1] += metainfo['input_size'][ + 1] / instance_heatmaps.shape[-2] / 2.0 + preds = [preds] + + else: + preds = [ + InstanceData( + keypoints=np.empty((0, self.num_keypoints, 2)), + keypoint_scores=np.empty((0, self.num_keypoints))) + ] + instance_heatmaps = torch.empty(0, self.num_keypoints, + *feats.shape[-2:]) + + if test_cfg.get('output_heatmaps', False): + pred_fields = [ + PixelData( + heatmaps=instance_heatmaps.reshape( + -1, *instance_heatmaps.shape[-2:])) + ] + return preds, pred_fields + else: + return preds + + def loss(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + feats (Tuple[Tensor]): The multi-stage features + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + train_cfg (dict): The runtime config for training process. + Defaults to {} + + Returns: + dict: A dictionary of losses. + """ + + # load targets + gt_heatmaps, gt_instance_coords, keypoint_weights = [], [], [] + heatmap_mask = [] + instance_imgids, gt_instance_heatmaps = [], [] + for i, d in enumerate(batch_data_samples): + gt_heatmaps.append(d.gt_fields.heatmaps) + gt_instance_coords.append(d.gt_instance_labels.instance_coords) + keypoint_weights.append(d.gt_instance_labels.keypoint_weights) + instance_imgids.append( + torch.ones( + len(d.gt_instance_labels.instance_coords), + dtype=torch.long) * i) + + instance_heatmaps = d.gt_fields.instance_heatmaps.reshape( + -1, self.num_keypoints, + *d.gt_fields.instance_heatmaps.shape[1:]) + gt_instance_heatmaps.append(instance_heatmaps) + + if 'heatmap_mask' in d.gt_fields: + heatmap_mask.append(d.gt_fields.heatmap_mask) + + gt_heatmaps = torch.stack(gt_heatmaps) + heatmap_mask = torch.stack(heatmap_mask) if heatmap_mask else None + + gt_instance_coords = torch.cat(gt_instance_coords, dim=0) + gt_instance_heatmaps = torch.cat(gt_instance_heatmaps, dim=0) + keypoint_weights = torch.cat(keypoint_weights, dim=0) + instance_imgids = torch.cat(instance_imgids).to(gt_heatmaps.device) + + # feed-forward + feats = feats[-1] + pred_instance_feats, pred_heatmaps = self.iia_module.forward_train( + feats, gt_instance_coords, instance_imgids) + + # conpute contrastive loss + contrastive_loss = 0 + for i in range(len(batch_data_samples)): + pred_instance_feat = pred_instance_feats[instance_imgids == i] + contrastive_loss += self.loss_module['contrastive']( + pred_instance_feat) + contrastive_loss = contrastive_loss / max(1, len(instance_imgids)) + + # limit the number of instances + max_train_instances = train_cfg.get('max_train_instances', -1) + if (max_train_instances > 0 + and len(instance_imgids) > max_train_instances): + selected_indices = torch.randperm( + len(instance_imgids), + device=gt_heatmaps.device, + dtype=torch.long)[:max_train_instances] + gt_instance_coords = gt_instance_coords[selected_indices] + keypoint_weights = keypoint_weights[selected_indices] + gt_instance_heatmaps = gt_instance_heatmaps[selected_indices] + instance_imgids = instance_imgids[selected_indices] + pred_instance_feats = pred_instance_feats[selected_indices] + + # calculate the decoupled heatmaps for each instance + pred_instance_heatmaps = self.gfd_module(feats, pred_instance_feats, + gt_instance_coords, + instance_imgids) + + # calculate losses + losses = { + 'loss/heatmap_coupled': + self.loss_module['heatmap_coupled'](pred_heatmaps, gt_heatmaps, + None, heatmap_mask) + } + if len(instance_imgids) > 0: + losses.update({ + 'loss/heatmap_decoupled': + self.loss_module['heatmap_decoupled'](pred_instance_heatmaps, + gt_instance_heatmaps, + keypoint_weights), + 'loss/contrastive': + contrastive_loss + }) + + return losses + + def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, + **kwargs): + """A hook function to convert old-version state dict of + :class:`CIDHead` (before MMPose v1.0.0) to a compatible format + of :class:`CIDHead`. + + The hook will be automatically registered during initialization. + """ + version = local_meta.get('version', None) + if version and version >= self._version: + return + + # convert old-version state dict + keys = list(state_dict.keys()) + for k in keys: + if 'keypoint_center_conv' in k: + v = state_dict.pop(k) + k = k.replace('keypoint_center_conv', + 'iia_module.keypoint_root_conv') + state_dict[k] = v + + if 'conv_down' in k: + v = state_dict.pop(k) + k = k.replace('conv_down', 'gfd_module.conv_down') + state_dict[k] = v + + if 'c_attn' in k: + v = state_dict.pop(k) + k = k.replace('c_attn', 'gfd_module.channel_attention') + state_dict[k] = v + + if 's_attn' in k: + v = state_dict.pop(k) + k = k.replace('s_attn', 'gfd_module.spatial_attention') + state_dict[k] = v + + if 'fuse_attn' in k: + v = state_dict.pop(k) + k = k.replace('fuse_attn', 'gfd_module.fuse_attention') + state_dict[k] = v + + if 'heatmap_conv' in k: + v = state_dict.pop(k) + k = k.replace('heatmap_conv', 'gfd_module.heatmap_conv') + state_dict[k] = v diff --git a/mmpose/models/heads/heatmap_heads/cpm_head.py b/mmpose/models/heads/heatmap_heads/cpm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba46357ec5cf72b29b43635a53354f2ed2fd048 --- /dev/null +++ b/mmpose/models/heads/heatmap_heads/cpm_head.py @@ -0,0 +1,307 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import torch +from mmcv.cnn import build_conv_layer, build_upsample_layer +from mmengine.structures import PixelData +from torch import Tensor, nn + +from mmpose.evaluation.functional import pose_pck_accuracy +from mmpose.models.utils.tta import flip_heatmaps +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (Features, MultiConfig, OptConfigType, + OptSampleList, Predictions) +from ..base_head import BaseHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class CPMHead(BaseHead): + """Multi-stage heatmap head introduced in `Convolutional Pose Machines`_ by + Wei et al (2016) and used by `Stacked Hourglass Networks`_ by Newell et al + (2016). The head consists of multiple branches, each of which has some + deconv layers and a simple conv2d layer. + + Args: + in_channels (int | Sequence[int]): Number of channels in the input + feature maps. + out_channels (int): Number of channels in the output heatmaps. + num_stages (int): Number of stages. + deconv_out_channels (Sequence[int], optional): The output channel + number of each deconv layer. Defaults to ``(256, 256, 256)`` + deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size + of each deconv layer. Each element should be either an integer for + both height and width dimensions, or a tuple of two integers for + the height and the width dimension respectively. + Defaults to ``(4, 4, 4)`` + final_layer (dict): Arguments of the final Conv2d layer. + Defaults to ``dict(kernel_size=1)`` + loss (Config | List[Config]): Config of the keypoint loss of different + stages. Defaults to use :class:`KeypointMSELoss`. + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`Convolutional Pose Machines`: https://arxiv.org/abs/1602.00134 + .. _`Stacked Hourglass Networks`: https://arxiv.org/abs/1603.06937 + """ + + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + out_channels: int, + num_stages: int, + deconv_out_channels: OptIntSeq = None, + deconv_kernel_sizes: OptIntSeq = None, + final_layer: dict = dict(kernel_size=1), + loss: MultiConfig = dict( + type='KeypointMSELoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + super().__init__(init_cfg) + + self.num_stages = num_stages + self.in_channels = in_channels + self.out_channels = out_channels + + if isinstance(loss, list): + if len(loss) != num_stages: + raise ValueError( + f'The length of loss_module({len(loss)}) did not match ' + f'`num_stages`({num_stages})') + self.loss_module = nn.ModuleList( + MODELS.build(_loss) for _loss in loss) + else: + self.loss_module = MODELS.build(loss) + + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + # build multi-stage deconv layers + self.multi_deconv_layers = nn.ModuleList([]) + if deconv_out_channels: + if deconv_kernel_sizes is None or len(deconv_out_channels) != len( + deconv_kernel_sizes): + raise ValueError( + '"deconv_out_channels" and "deconv_kernel_sizes" should ' + 'be integer sequences with the same length. Got ' + f'mismatched lengths {deconv_out_channels} and ' + f'{deconv_kernel_sizes}') + + for _ in range(self.num_stages): + deconv_layers = self._make_deconv_layers( + in_channels=in_channels, + layer_out_channels=deconv_out_channels, + layer_kernel_sizes=deconv_kernel_sizes, + ) + self.multi_deconv_layers.append(deconv_layers) + in_channels = deconv_out_channels[-1] + else: + for _ in range(self.num_stages): + self.multi_deconv_layers.append(nn.Identity()) + + # build multi-stage final layers + self.multi_final_layers = nn.ModuleList([]) + if final_layer is not None: + cfg = dict( + type='Conv2d', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1) + cfg.update(final_layer) + for _ in range(self.num_stages): + self.multi_final_layers.append(build_conv_layer(cfg)) + else: + for _ in range(self.num_stages): + self.multi_final_layers.append(nn.Identity()) + + @property + def default_init_cfg(self): + init_cfg = [ + dict( + type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001), + dict(type='Constant', layer='BatchNorm2d', val=1) + ] + return init_cfg + + def _make_deconv_layers(self, in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int]) -> nn.Module: + """Create deconvolutional layers by given parameters.""" + + layers = [] + for out_channels, kernel_size in zip(layer_out_channels, + layer_kernel_sizes): + if kernel_size == 4: + padding = 1 + output_padding = 0 + elif kernel_size == 3: + padding = 1 + output_padding = 1 + elif kernel_size == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Unsupported kernel size {kernel_size} for' + 'deconvlutional layers in ' + f'{self.__class__.__name__}') + cfg = dict( + type='deconv', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False) + layers.append(build_upsample_layer(cfg)) + layers.append(nn.BatchNorm2d(num_features=out_channels)) + layers.append(nn.ReLU(inplace=True)) + in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, feats: Sequence[Tensor]) -> List[Tensor]: + """Forward the network. The input is multi-stage feature maps and the + output is a list of heatmaps from multiple stages. + + Args: + feats (Sequence[Tensor]): Multi-stage feature maps. + + Returns: + List[Tensor]: A list of output heatmaps from multiple stages. + """ + out = [] + assert len(feats) == self.num_stages, ( + f'The length of feature maps did not match the ' + f'`num_stages` in {self.__class__.__name__}') + for i in range(self.num_stages): + y = self.multi_deconv_layers[i](feats[i]) + y = self.multi_final_layers[i](y) + out.append(y) + + return out + + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: OptConfigType = {}) -> Predictions: + """Predict results from multi-stage feature maps. + + Args: + feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage + features (or multiple multi-stage features in TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. Defaults + to {} + + Returns: + Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If + ``test_cfg['output_heatmap']==True``, return both pose and heatmap + prediction; otherwise only return the pose prediction. + + The pose prediction is a list of ``InstanceData``, each contains + the following fields: + + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + + The heatmap prediction is a list of ``PixelData``, each contains + the following fields: + + - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) + """ + + if test_cfg.get('flip_test', False): + # TTA: flip test + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + _feats, _feats_flip = feats + _batch_heatmaps = self.forward(_feats)[-1] + _batch_heatmaps_flip = flip_heatmaps( + self.forward(_feats_flip)[-1], + flip_mode=test_cfg.get('flip_mode', 'heatmap'), + flip_indices=flip_indices, + shift_heatmap=test_cfg.get('shift_heatmap', False)) + batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5 + else: + multi_stage_heatmaps = self.forward(feats) + batch_heatmaps = multi_stage_heatmaps[-1] + + preds = self.decode(batch_heatmaps) + + if test_cfg.get('output_heatmaps', False): + pred_fields = [ + PixelData(heatmaps=hm) for hm in batch_heatmaps.detach() + ] + return preds, pred_fields + else: + return preds + + def loss(self, + feats: Sequence[Tensor], + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + feats (Sequence[Tensor]): Multi-stage feature maps. + batch_data_samples (List[:obj:`PoseDataSample`]): The Data + Samples. It usually includes information such as + `gt_instances`. + train_cfg (Config, optional): The training config. + + Returns: + dict: A dictionary of loss components. + """ + multi_stage_pred_heatmaps = self.forward(feats) + + gt_heatmaps = torch.stack( + [d.gt_fields.heatmaps for d in batch_data_samples]) + keypoint_weights = torch.cat([ + d.gt_instance_labels.keypoint_weights for d in batch_data_samples + ]) + + # calculate losses over multiple stages + losses = dict() + for i in range(self.num_stages): + if isinstance(self.loss_module, nn.ModuleList): + # use different loss_module over different stages + loss_func = self.loss_module[i] + else: + # use the same loss_module over different stages + loss_func = self.loss_module + + # the `gt_heatmaps` and `keypoint_weights` used to calculate loss + # for different stages are the same + loss_i = loss_func(multi_stage_pred_heatmaps[i], gt_heatmaps, + keypoint_weights) + + if 'loss_kpt' not in losses: + losses['loss_kpt'] = loss_i + else: + losses['loss_kpt'] += loss_i + + # calculate accuracy + _, avg_acc, _ = pose_pck_accuracy( + output=to_numpy(multi_stage_pred_heatmaps[-1]), + target=to_numpy(gt_heatmaps), + mask=to_numpy(keypoint_weights) > 0) + + acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device) + losses.update(acc_pose=acc_pose) + + return losses diff --git a/mmpose/models/heads/heatmap_heads/heatmap_head.py b/mmpose/models/heads/heatmap_heads/heatmap_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0b0fa3f475b3e54ca0dd267e486d1e0424b35ab6 --- /dev/null +++ b/mmpose/models/heads/heatmap_heads/heatmap_head.py @@ -0,0 +1,369 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, Union + +import torch +from mmcv.cnn import build_conv_layer, build_upsample_layer +from mmengine.structures import PixelData +from torch import Tensor, nn + +from mmpose.evaluation.functional import pose_pck_accuracy +from mmpose.models.utils.tta import flip_heatmaps +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, Features, OptConfigType, + OptSampleList, Predictions) +from ..base_head import BaseHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class HeatmapHead(BaseHead): + """Top-down heatmap head introduced in `Simple Baselines`_ by Xiao et al + (2018). The head is composed of a few deconvolutional layers followed by a + convolutional layer to generate heatmaps from low-resolution feature maps. + + Args: + in_channels (int | Sequence[int]): Number of channels in the input + feature map + out_channels (int): Number of channels in the output heatmap + deconv_out_channels (Sequence[int], optional): The output channel + number of each deconv layer. Defaults to ``(256, 256, 256)`` + deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size + of each deconv layer. Each element should be either an integer for + both height and width dimensions, or a tuple of two integers for + the height and the width dimension respectively.Defaults to + ``(4, 4, 4)`` + conv_out_channels (Sequence[int], optional): The output channel number + of each intermediate conv layer. ``None`` means no intermediate + conv layer between deconv layers and the final conv layer. + Defaults to ``None`` + conv_kernel_sizes (Sequence[int | tuple], optional): The kernel size + of each intermediate conv layer. Defaults to ``None`` + final_layer (dict): Arguments of the final Conv2d layer. + Defaults to ``dict(kernel_size=1)`` + loss (Config): Config of the keypoint loss. Defaults to use + :class:`KeypointMSELoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + extra (dict, optional): Extra configurations. + Defaults to ``None`` + + .. _`Simple Baselines`: https://arxiv.org/abs/1804.06208 + """ + + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + out_channels: int, + deconv_out_channels: OptIntSeq = (256, 256, 256), + deconv_kernel_sizes: OptIntSeq = (4, 4, 4), + conv_out_channels: OptIntSeq = None, + conv_kernel_sizes: OptIntSeq = None, + final_layer: dict = dict(kernel_size=1), + loss: ConfigType = dict( + type='KeypointMSELoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + self.loss_module = MODELS.build(loss) + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + if deconv_out_channels: + if deconv_kernel_sizes is None or len(deconv_out_channels) != len( + deconv_kernel_sizes): + raise ValueError( + '"deconv_out_channels" and "deconv_kernel_sizes" should ' + 'be integer sequences with the same length. Got ' + f'mismatched lengths {deconv_out_channels} and ' + f'{deconv_kernel_sizes}') + + self.deconv_layers = self._make_deconv_layers( + in_channels=in_channels, + layer_out_channels=deconv_out_channels, + layer_kernel_sizes=deconv_kernel_sizes, + ) + in_channels = deconv_out_channels[-1] + else: + self.deconv_layers = nn.Identity() + + if conv_out_channels: + if conv_kernel_sizes is None or len(conv_out_channels) != len( + conv_kernel_sizes): + raise ValueError( + '"conv_out_channels" and "conv_kernel_sizes" should ' + 'be integer sequences with the same length. Got ' + f'mismatched lengths {conv_out_channels} and ' + f'{conv_kernel_sizes}') + + self.conv_layers = self._make_conv_layers( + in_channels=in_channels, + layer_out_channels=conv_out_channels, + layer_kernel_sizes=conv_kernel_sizes) + in_channels = conv_out_channels[-1] + else: + self.conv_layers = nn.Identity() + + if final_layer is not None: + cfg = dict( + type='Conv2d', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1) + cfg.update(final_layer) + self.final_layer = build_conv_layer(cfg) + else: + self.final_layer = nn.Identity() + + # Register the hook to automatically convert old version state dicts + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def _make_conv_layers(self, in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int]) -> nn.Module: + """Create convolutional layers by given parameters.""" + + layers = [] + for out_channels, kernel_size in zip(layer_out_channels, + layer_kernel_sizes): + padding = (kernel_size - 1) // 2 + cfg = dict( + type='Conv2d', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding) + layers.append(build_conv_layer(cfg)) + layers.append(nn.BatchNorm2d(num_features=out_channels)) + layers.append(nn.ReLU(inplace=True)) + in_channels = out_channels + + return nn.Sequential(*layers) + + def _make_deconv_layers(self, in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int]) -> nn.Module: + """Create deconvolutional layers by given parameters.""" + + layers = [] + for out_channels, kernel_size in zip(layer_out_channels, + layer_kernel_sizes): + if kernel_size == 4: + padding = 1 + output_padding = 0 + elif kernel_size == 3: + padding = 1 + output_padding = 1 + elif kernel_size == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Unsupported kernel size {kernel_size} for' + 'deconvlutional layers in ' + f'{self.__class__.__name__}') + cfg = dict( + type='deconv', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False) + layers.append(build_upsample_layer(cfg)) + layers.append(nn.BatchNorm2d(num_features=out_channels)) + layers.append(nn.ReLU(inplace=True)) + in_channels = out_channels + + return nn.Sequential(*layers) + + @property + def default_init_cfg(self): + init_cfg = [ + dict( + type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001), + dict(type='Constant', layer='BatchNorm2d', val=1) + ] + return init_cfg + + def forward(self, feats: Tuple[Tensor]) -> Tensor: + """Forward the network. The input is multi scale feature maps and the + output is the heatmap. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + Tensor: output heatmap. + """ + x = feats[-1] + + x = self.deconv_layers(x) + x = self.conv_layers(x) + x = self.final_layer(x) + + return x + + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from features. + + Args: + feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage + features (or multiple multi-stage features in TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. Defaults + to {} + + Returns: + Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If + ``test_cfg['output_heatmap']==True``, return both pose and heatmap + prediction; otherwise only return the pose prediction. + + The pose prediction is a list of ``InstanceData``, each contains + the following fields: + + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + + The heatmap prediction is a list of ``PixelData``, each contains + the following fields: + + - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) + """ + + if test_cfg.get('flip_test', False): + # TTA: flip test -> feats = [orig, flipped] + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + _feats, _feats_flip = feats + _batch_heatmaps = self.forward(_feats) + _batch_heatmaps_flip = flip_heatmaps( + self.forward(_feats_flip), + flip_mode=test_cfg.get('flip_mode', 'heatmap'), + flip_indices=flip_indices, + shift_heatmap=test_cfg.get('shift_heatmap', False)) + batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5 + else: + batch_heatmaps = self.forward(feats) + + preds = self.decode(batch_heatmaps) + + if test_cfg.get('output_heatmaps', False): + pred_fields = [ + PixelData(heatmaps=hm) for hm in batch_heatmaps.detach() + ] + return preds, pred_fields + else: + return preds + + def loss(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + feats (Tuple[Tensor]): The multi-stage features + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + train_cfg (dict): The runtime config for training process. + Defaults to {} + + Returns: + dict: A dictionary of losses. + """ + pred_fields = self.forward(feats) + gt_heatmaps = torch.stack( + [d.gt_fields.heatmaps for d in batch_data_samples]) + keypoint_weights = torch.cat([ + d.gt_instance_labels.keypoint_weights for d in batch_data_samples + ]) + + # calculate losses + losses = dict() + loss = self.loss_module(pred_fields, gt_heatmaps, keypoint_weights) + + losses.update(loss_kpt=loss) + + # calculate accuracy + if train_cfg.get('compute_acc', True): + _, avg_acc, _ = pose_pck_accuracy( + output=to_numpy(pred_fields), + target=to_numpy(gt_heatmaps), + mask=to_numpy(keypoint_weights) > 0) + + acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device) + losses.update(acc_pose=acc_pose) + + return losses + + def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, + **kwargs): + """A hook function to convert old-version state dict of + :class:`DeepposeRegressionHead` (before MMPose v1.0.0) to a + compatible format of :class:`RegressionHead`. + + The hook will be automatically registered during initialization. + """ + version = local_meta.get('version', None) + if version and version >= self._version: + return + + # convert old-version state dict + keys = list(state_dict.keys()) + for _k in keys: + if not _k.startswith(prefix): + continue + v = state_dict.pop(_k) + k = _k[len(prefix):] + # In old version, "final_layer" includes both intermediate + # conv layers (new "conv_layers") and final conv layers (new + # "final_layer"). + # + # If there is no intermediate conv layer, old "final_layer" will + # have keys like "final_layer.xxx", which should be still + # named "final_layer.xxx"; + # + # If there are intermediate conv layers, old "final_layer" will + # have keys like "final_layer.n.xxx", where the weights of the last + # one should be renamed "final_layer.xxx", and others should be + # renamed "conv_layers.n.xxx" + k_parts = k.split('.') + if k_parts[0] == 'final_layer': + if len(k_parts) == 3: + assert isinstance(self.conv_layers, nn.Sequential) + idx = int(k_parts[1]) + if idx < len(self.conv_layers): + # final_layer.n.xxx -> conv_layers.n.xxx + k_new = 'conv_layers.' + '.'.join(k_parts[1:]) + else: + # final_layer.n.xxx -> final_layer.xxx + k_new = 'final_layer.' + k_parts[2] + else: + # final_layer.xxx remains final_layer.xxx + k_new = k + else: + k_new = k + + state_dict[prefix + k_new] = v diff --git a/mmpose/models/heads/heatmap_heads/mspn_head.py b/mmpose/models/heads/heatmap_heads/mspn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7cddf7988bfc57cae314ef944f44b4d0d7df09 --- /dev/null +++ b/mmpose/models/heads/heatmap_heads/mspn_head.py @@ -0,0 +1,432 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Optional, Sequence, Union + +import torch +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, Linear, + build_activation_layer, build_norm_layer) +from mmengine.structures import PixelData +from torch import Tensor, nn + +from mmpose.evaluation.functional import pose_pck_accuracy +from mmpose.models.utils.tta import flip_heatmaps +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, MultiConfig, OptConfigType, + OptSampleList, Predictions) +from ..base_head import BaseHead + +OptIntSeq = Optional[Sequence[int]] +MSMUFeatures = Sequence[Sequence[Tensor]] # Multi-stage multi-unit features + + +class PRM(nn.Module): + """Pose Refine Machine. + + Please refer to "Learning Delicate Local Representations + for Multi-Person Pose Estimation" (ECCV 2020). + + Args: + out_channels (int): Number of the output channels, equals to + the number of keypoints. + norm_cfg (Config): Config to construct the norm layer. + Defaults to ``dict(type='BN')`` + """ + + def __init__(self, + out_channels: int, + norm_cfg: ConfigType = dict(type='BN')): + super().__init__() + + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + self.out_channels = out_channels + self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) + self.middle_path = nn.Sequential( + Linear(self.out_channels, self.out_channels), + build_norm_layer(dict(type='BN1d'), out_channels)[1], + build_activation_layer(dict(type='ReLU')), + Linear(self.out_channels, self.out_channels), + build_norm_layer(dict(type='BN1d'), out_channels)[1], + build_activation_layer(dict(type='ReLU')), + build_activation_layer(dict(type='Sigmoid'))) + + self.bottom_path = nn.Sequential( + ConvModule( + self.out_channels, + self.out_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + inplace=False), + DepthwiseSeparableConvModule( + self.out_channels, + 1, + kernel_size=9, + stride=1, + padding=4, + norm_cfg=norm_cfg, + inplace=False), build_activation_layer(dict(type='Sigmoid'))) + self.conv_bn_relu_prm_1 = ConvModule( + self.out_channels, + self.out_channels, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + inplace=False) + + def forward(self, x: Tensor) -> Tensor: + """Forward the network. The input heatmaps will be refined. + + Args: + x (Tensor): The input heatmaps. + + Returns: + Tensor: output heatmaps. + """ + out = self.conv_bn_relu_prm_1(x) + out_1 = out + + out_2 = self.global_pooling(out_1) + out_2 = out_2.view(out_2.size(0), -1) + out_2 = self.middle_path(out_2) + out_2 = out_2.unsqueeze(2) + out_2 = out_2.unsqueeze(3) + + out_3 = self.bottom_path(out_1) + out = out_1 * (1 + out_2 * out_3) + + return out + + +class PredictHeatmap(nn.Module): + """Predict the heatmap for an input feature. + + Args: + unit_channels (int): Number of input channels. + out_channels (int): Number of output channels. + out_shape (tuple): Shape of the output heatmaps. + use_prm (bool): Whether to use pose refine machine. Default: False. + norm_cfg (Config): Config to construct the norm layer. + Defaults to ``dict(type='BN')`` + """ + + def __init__(self, + unit_channels: int, + out_channels: int, + out_shape: tuple, + use_prm: bool = False, + norm_cfg: ConfigType = dict(type='BN')): + + super().__init__() + + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + self.unit_channels = unit_channels + self.out_channels = out_channels + self.out_shape = out_shape + self.use_prm = use_prm + if use_prm: + self.prm = PRM(out_channels, norm_cfg=norm_cfg) + self.conv_layers = nn.Sequential( + ConvModule( + unit_channels, + unit_channels, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + inplace=False), + ConvModule( + unit_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=None, + inplace=False)) + + def forward(self, feature: Tensor) -> Tensor: + """Forward the network. + + Args: + feature (Tensor): The input feature maps. + + Returns: + Tensor: output heatmaps. + """ + feature = self.conv_layers(feature) + output = nn.functional.interpolate( + feature, size=self.out_shape, mode='bilinear', align_corners=True) + if self.use_prm: + output = self.prm(output) + return output + + +@MODELS.register_module() +class MSPNHead(BaseHead): + """Multi-stage multi-unit heatmap head introduced in `Multi-Stage Pose + estimation Network (MSPN)`_ by Li et al (2019), and used by `Residual Steps + Networks (RSN)`_ by Cai et al (2020). The head consists of multiple stages + and each stage consists of multiple units. Each unit of each stage has some + conv layers. + + Args: + num_stages (int): Number of stages. + num_units (int): Number of units in each stage. + out_shape (tuple): The output shape of the output heatmaps. + unit_channels (int): Number of input channels. + out_channels (int): Number of output channels. + out_shape (tuple): Shape of the output heatmaps. + use_prm (bool): Whether to use pose refine machine (PRM). + Defaults to ``False``. + norm_cfg (Config): Config to construct the norm layer. + Defaults to ``dict(type='BN')`` + loss (Config | List[Config]): Config of the keypoint loss for + different stages and different units. + Defaults to use :class:`KeypointMSELoss`. + level_indices (Sequence[int]): The indices that specified the level + of target heatmaps. + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`MSPN`: https://arxiv.org/abs/1901.00148 + .. _`RSN`: https://arxiv.org/abs/2003.04030 + """ + _version = 2 + + def __init__(self, + num_stages: int = 4, + num_units: int = 4, + out_shape: tuple = (64, 48), + unit_channels: int = 256, + out_channels: int = 17, + use_prm: bool = False, + norm_cfg: ConfigType = dict(type='BN'), + level_indices: Sequence[int] = [], + loss: MultiConfig = dict( + type='KeypointMSELoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + if init_cfg is None: + init_cfg = self.default_init_cfg + super().__init__(init_cfg) + + self.num_stages = num_stages + self.num_units = num_units + self.out_shape = out_shape + self.unit_channels = unit_channels + self.out_channels = out_channels + if len(level_indices) != num_stages * num_units: + raise ValueError( + f'The length of level_indices({len(level_indices)}) did not ' + f'match `num_stages`({num_stages}) * `num_units`({num_units})') + + self.level_indices = level_indices + + if isinstance(loss, list) and len(loss) != num_stages * num_units: + raise ValueError( + f'The length of loss_module({len(loss)}) did not match ' + f'`num_stages`({num_stages}) * `num_units`({num_units})') + + if isinstance(loss, list): + if len(loss) != num_stages * num_units: + raise ValueError( + f'The length of loss_module({len(loss)}) did not match ' + f'`num_stages`({num_stages}) * `num_units`({num_units})') + self.loss_module = nn.ModuleList( + MODELS.build(_loss) for _loss in loss) + else: + self.loss_module = MODELS.build(loss) + + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + # Protect mutable default arguments + norm_cfg = copy.deepcopy(norm_cfg) + + self.predict_layers = nn.ModuleList([]) + for i in range(self.num_stages): + for j in range(self.num_units): + self.predict_layers.append( + PredictHeatmap( + unit_channels, + out_channels, + out_shape, + use_prm, + norm_cfg=norm_cfg)) + + @property + def default_init_cfg(self): + """Default config for weight initialization.""" + init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict(type='Normal', layer='Linear', std=0.01), + dict(type='Constant', layer='BatchNorm2d', val=1), + ] + return init_cfg + + def forward(self, feats: Sequence[Sequence[Tensor]]) -> List[Tensor]: + """Forward the network. The input is multi-stage multi-unit feature + maps and the output is a list of heatmaps from multiple stages. + + Args: + feats (Sequence[Sequence[Tensor]]): Feature maps from multiple + stages and units. + + Returns: + List[Tensor]: A list of output heatmaps from multiple stages + and units. + """ + out = [] + assert len(feats) == self.num_stages, ( + f'The length of feature maps did not match the ' + f'`num_stages` in {self.__class__.__name__}') + for feat in feats: + assert len(feat) == self.num_units, ( + f'The length of feature maps did not match the ' + f'`num_units` in {self.__class__.__name__}') + for f in feat: + assert f.shape[1] == self.unit_channels, ( + f'The number of feature map channels did not match the ' + f'`unit_channels` in {self.__class__.__name__}') + + for i in range(self.num_stages): + for j in range(self.num_units): + y = self.predict_layers[i * self.num_units + j](feats[i][j]) + out.append(y) + return out + + def predict(self, + feats: Union[MSMUFeatures, List[MSMUFeatures]], + batch_data_samples: OptSampleList, + test_cfg: OptConfigType = {}) -> Predictions: + """Predict results from multi-stage feature maps. + + Args: + feats (Sequence[Sequence[Tensor]]): Multi-stage multi-unit + features (or multiple MSMU features for TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance_labels`. + test_cfg (Config, optional): The testing/inference config + + Returns: + Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If + ``test_cfg['output_heatmap']==True``, return both pose and heatmap + prediction; otherwise only return the pose prediction. + + The pose prediction is a list of ``InstanceData``, each contains + the following fields: + + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + + The heatmap prediction is a list of ``PixelData``, each contains + the following fields: + + - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) + """ + # multi-stage multi-unit batch heatmaps + if test_cfg.get('flip_test', False): + # TTA: flip test + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + _feats, _feats_flip = feats + _batch_heatmaps = self.forward(_feats)[-1] + _batch_heatmaps_flip = flip_heatmaps( + self.forward(_feats_flip)[-1], + flip_mode=test_cfg.get('flip_mode', 'heatmap'), + flip_indices=flip_indices, + shift_heatmap=test_cfg.get('shift_heatmap', False)) + batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5 + else: + msmu_batch_heatmaps = self.forward(feats) + batch_heatmaps = msmu_batch_heatmaps[-1] + + preds = self.decode(batch_heatmaps) + + if test_cfg.get('output_heatmaps', False): + pred_fields = [ + PixelData(heatmaps=hm) for hm in batch_heatmaps.detach() + ] + return preds, pred_fields + else: + return preds + + def loss(self, + feats: MSMUFeatures, + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Note: + - batch_size: B + - num_output_heatmap_levels: L + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + - num_instances: N (usually 1 in topdown heatmap heads) + + Args: + feats (Sequence[Sequence[Tensor]]): Feature maps from multiple + stages and units + batch_data_samples (List[:obj:`PoseDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance_labels` and `gt_fields`. + train_cfg (Config, optional): The training config + + Returns: + dict: A dictionary of loss components. + """ + # multi-stage multi-unit predict heatmaps + msmu_pred_heatmaps = self.forward(feats) + + keypoint_weights = torch.cat([ + d.gt_instance_labels.keypoint_weights for d in batch_data_samples + ]) # shape: [B*N, L, K] + + # calculate losses over multiple stages and multiple units + losses = dict() + for i in range(self.num_stages * self.num_units): + if isinstance(self.loss_module, nn.ModuleList): + # use different loss_module over different stages and units + loss_func = self.loss_module[i] + else: + # use the same loss_module over different stages and units + loss_func = self.loss_module + + # select `gt_heatmaps` and `keypoint_weights` for different level + # according to `self.level_indices` to calculate loss + gt_heatmaps = torch.stack([ + d.gt_fields[self.level_indices[i]].heatmaps + for d in batch_data_samples + ]) + loss_i = loss_func(msmu_pred_heatmaps[i], gt_heatmaps, + keypoint_weights[:, self.level_indices[i]]) + + if 'loss_kpt' not in losses: + losses['loss_kpt'] = loss_i + else: + losses['loss_kpt'] += loss_i + + # calculate accuracy + _, avg_acc, _ = pose_pck_accuracy( + output=to_numpy(msmu_pred_heatmaps[-1]), + target=to_numpy(gt_heatmaps), + mask=to_numpy(keypoint_weights[:, -1]) > 0) + + acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device) + losses.update(acc_pose=acc_pose) + + return losses diff --git a/mmpose/models/heads/heatmap_heads/vipnas_head.py b/mmpose/models/heads/heatmap_heads/vipnas_head.py new file mode 100644 index 0000000000000000000000000000000000000000..949ee95b096124a162f6d9719446fa80bd26a201 --- /dev/null +++ b/mmpose/models/heads/heatmap_heads/vipnas_head.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Union + +from mmcv.cnn import build_conv_layer, build_upsample_layer +from torch import nn + +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.typing import ConfigType, OptConfigType +from .heatmap_head import HeatmapHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class ViPNASHead(HeatmapHead): + """ViPNAS heatmap head introduced in `ViPNAS`_ by Xu et al (2021). The head + is composed of a few deconvolutional layers followed by a convolutional + layer to generate heatmaps from low-resolution feature maps. Specifically, + different from the :class: `HeatmapHead` introduced by `Simple Baselines`_, + the group numbers in the deconvolutional layers are elastic and thus can be + optimized by neural architecture search (NAS). + + Args: + in_channels (int | Sequence[int]): Number of channels in the input + feature map + out_channels (int): Number of channels in the output heatmap + deconv_out_channels (Sequence[int], optional): The output channel + number of each deconv layer. Defaults to ``(144, 144, 144)`` + deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size + of each deconv layer. Each element should be either an integer for + both height and width dimensions, or a tuple of two integers for + the height and the width dimension respectively.Defaults to + ``(4, 4, 4)`` + deconv_num_groups (Sequence[int], optional): The group number of each + deconv layer. Defaults to ``(16, 16, 16)`` + conv_out_channels (Sequence[int], optional): The output channel number + of each intermediate conv layer. ``None`` means no intermediate + conv layer between deconv layers and the final conv layer. + Defaults to ``None`` + conv_kernel_sizes (Sequence[int | tuple], optional): The kernel size + of each intermediate conv layer. Defaults to ``None`` + final_layer (dict): Arguments of the final Conv2d layer. + Defaults to ``dict(kernel_size=1)`` + loss (Config): Config of the keypoint loss. Defaults to use + :class:`KeypointMSELoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`ViPNAS`: https://arxiv.org/abs/2105.10154 + .. _`Simple Baselines`: https://arxiv.org/abs/1804.06208 + """ + + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + out_channels: int, + deconv_out_channels: OptIntSeq = (144, 144, 144), + deconv_kernel_sizes: OptIntSeq = (4, 4, 4), + deconv_num_groups: OptIntSeq = (16, 16, 16), + conv_out_channels: OptIntSeq = None, + conv_kernel_sizes: OptIntSeq = None, + final_layer: dict = dict(kernel_size=1), + loss: ConfigType = dict( + type='KeypointMSELoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super(HeatmapHead, self).__init__(init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + self.loss_module = MODELS.build(loss) + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + if deconv_out_channels: + if deconv_kernel_sizes is None or len(deconv_out_channels) != len( + deconv_kernel_sizes): + raise ValueError( + '"deconv_out_channels" and "deconv_kernel_sizes" should ' + 'be integer sequences with the same length. Got ' + f'mismatched lengths {deconv_out_channels} and ' + f'{deconv_kernel_sizes}') + if deconv_num_groups is None or len(deconv_out_channels) != len( + deconv_num_groups): + raise ValueError( + '"deconv_out_channels" and "deconv_num_groups" should ' + 'be integer sequences with the same length. Got ' + f'mismatched lengths {deconv_out_channels} and ' + f'{deconv_num_groups}') + + self.deconv_layers = self._make_deconv_layers( + in_channels=in_channels, + layer_out_channels=deconv_out_channels, + layer_kernel_sizes=deconv_kernel_sizes, + layer_groups=deconv_num_groups, + ) + in_channels = deconv_out_channels[-1] + else: + self.deconv_layers = nn.Identity() + + if conv_out_channels: + if conv_kernel_sizes is None or len(conv_out_channels) != len( + conv_kernel_sizes): + raise ValueError( + '"conv_out_channels" and "conv_kernel_sizes" should ' + 'be integer sequences with the same length. Got ' + f'mismatched lengths {conv_out_channels} and ' + f'{conv_kernel_sizes}') + + self.conv_layers = self._make_conv_layers( + in_channels=in_channels, + layer_out_channels=conv_out_channels, + layer_kernel_sizes=conv_kernel_sizes) + in_channels = conv_out_channels[-1] + else: + self.conv_layers = nn.Identity() + + if final_layer is not None: + cfg = dict( + type='Conv2d', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1) + cfg.update(final_layer) + self.final_layer = build_conv_layer(cfg) + else: + self.final_layer = nn.Identity() + + # Register the hook to automatically convert old version state dicts + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def _make_deconv_layers(self, in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int], + layer_groups: Sequence[int]) -> nn.Module: + """Create deconvolutional layers by given parameters.""" + + layers = [] + for out_channels, kernel_size, groups in zip(layer_out_channels, + layer_kernel_sizes, + layer_groups): + if kernel_size == 4: + padding = 1 + output_padding = 0 + elif kernel_size == 3: + padding = 1 + output_padding = 1 + elif kernel_size == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Unsupported kernel size {kernel_size} for' + 'deconvlutional layers in ' + f'{self.__class__.__name__}') + cfg = dict( + type='deconv', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + groups=groups, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False) + layers.append(build_upsample_layer(cfg)) + layers.append(nn.BatchNorm2d(num_features=out_channels)) + layers.append(nn.ReLU(inplace=True)) + in_channels = out_channels + + return nn.Sequential(*layers) diff --git a/mmpose/models/heads/hybrid_heads/__init__.py b/mmpose/models/heads/hybrid_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55d5a211c1c0e17a61968ca5a266a797587f8c83 --- /dev/null +++ b/mmpose/models/heads/hybrid_heads/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dekr_head import DEKRHead + +__all__ = [ + 'DEKRHead', +] diff --git a/mmpose/models/heads/hybrid_heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/hybrid_heads/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1669ef83fbfcf02d9e985d2e9d1b491119bde95 Binary files /dev/null and b/mmpose/models/heads/hybrid_heads/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/heads/hybrid_heads/__pycache__/dekr_head.cpython-38.pyc b/mmpose/models/heads/hybrid_heads/__pycache__/dekr_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7c3940457ac8d16f34aae702d9b405bb89c96f8 Binary files /dev/null and b/mmpose/models/heads/hybrid_heads/__pycache__/dekr_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/hybrid_heads/dekr_head.py b/mmpose/models/heads/hybrid_heads/dekr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..41f7cfc4ce9f7cbb061c18ba14a4847a67a07ffc --- /dev/null +++ b/mmpose/models/heads/hybrid_heads/dekr_head.py @@ -0,0 +1,581 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple, Union + +import torch +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmengine.model import BaseModule, ModuleDict, Sequential +from mmengine.structures import InstanceData, PixelData +from torch import Tensor + +from mmpose.evaluation.functional.nms import nearby_joints_nms +from mmpose.models.utils.tta import flip_heatmaps +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, Features, InstanceList, + OptConfigType, OptSampleList, Predictions) +from ...backbones.resnet import BasicBlock +from ..base_head import BaseHead + +try: + from mmcv.ops import DeformConv2d + has_mmcv_full = True +except (ImportError, ModuleNotFoundError): + has_mmcv_full = False + + +class AdaptiveActivationBlock(BaseModule): + """Adaptive activation convolution block. "Bottom-up human pose estimation + via disentangled keypoint regression", CVPR'2021. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + groups (int): Number of groups. Generally equal to the + number of joints. + norm_cfg (dict): Config for normalization layers. + act_cfg (dict): Config for activation layers. + """ + + def __init__(self, + in_channels, + out_channels, + groups=1, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(AdaptiveActivationBlock, self).__init__(init_cfg=init_cfg) + + assert in_channels % groups == 0 and out_channels % groups == 0 + self.groups = groups + + regular_matrix = torch.tensor([[-1, -1, -1, 0, 0, 0, 1, 1, 1], + [-1, 0, 1, -1, 0, 1, -1, 0, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1]]) + self.register_buffer('regular_matrix', regular_matrix.float()) + + self.transform_matrix_conv = build_conv_layer( + dict(type='Conv2d'), + in_channels=in_channels, + out_channels=6 * groups, + kernel_size=3, + padding=1, + groups=groups, + bias=True) + + if has_mmcv_full: + self.adapt_conv = DeformConv2d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + groups=groups, + deform_groups=groups) + else: + raise ImportError('Please install the full version of mmcv ' + 'to use `DeformConv2d`.') + + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + B, _, H, W = x.size() + residual = x + + affine_matrix = self.transform_matrix_conv(x) + affine_matrix = affine_matrix.permute(0, 2, 3, 1).contiguous() + affine_matrix = affine_matrix.view(B, H, W, self.groups, 2, 3) + offset = torch.matmul(affine_matrix, self.regular_matrix) + offset = offset.transpose(4, 5).reshape(B, H, W, self.groups * 18) + offset = offset.permute(0, 3, 1, 2).contiguous() + + x = self.adapt_conv(x, offset) + x = self.norm(x) + x = self.act(x + residual) + + return x + + +class RescoreNet(BaseModule): + """Rescore net used to predict the OKS score of predicted pose. We use the + off-the-shelf rescore net pretrained by authors of DEKR. + + Args: + in_channels (int): Input channels + norm_indexes (Tuple(int)): Indices of torso in skeleton + init_cfg (dict, optional): Initialization config dict + """ + + def __init__( + self, + in_channels, + norm_indexes, + init_cfg=None, + ): + super(RescoreNet, self).__init__(init_cfg=init_cfg) + + self.norm_indexes = norm_indexes + + hidden = 256 + + self.l1 = torch.nn.Linear(in_channels, hidden, bias=True) + self.l2 = torch.nn.Linear(hidden, hidden, bias=True) + self.l3 = torch.nn.Linear(hidden, 1, bias=True) + self.relu = torch.nn.ReLU() + + def make_feature(self, keypoints, keypoint_scores, skeleton): + """Combine original scores, joint distance and relative distance to + make feature. + + Args: + keypoints (torch.Tensor): predicetd keypoints + keypoint_scores (torch.Tensor): predicetd keypoint scores + skeleton (list(list(int))): joint links + + Returns: + torch.Tensor: feature for each instance + """ + joint_1, joint_2 = zip(*skeleton) + num_link = len(skeleton) + + joint_relate = (keypoints[:, joint_1] - + keypoints[:, joint_2])[:, :, :2] + joint_length = joint_relate.norm(dim=2) + + # To use the torso distance to normalize + normalize = (joint_length[:, self.norm_indexes[0]] + + joint_length[:, self.norm_indexes[1]]) / 2 + normalize = normalize.unsqueeze(1).expand(normalize.size(0), num_link) + normalize = normalize.clamp(min=1).contiguous() + + joint_length = joint_length / normalize[:, :] + joint_relate = joint_relate / normalize.unsqueeze(-1) + joint_relate = joint_relate.flatten(1) + + feature = torch.cat((joint_relate, joint_length, keypoint_scores), + dim=1).float() + return feature + + def forward(self, keypoints, keypoint_scores, skeleton): + feature = self.make_feature(keypoints, keypoint_scores, skeleton) + x = self.relu(self.l1(feature)) + x = self.relu(self.l2(x)) + x = self.l3(x) + return x.squeeze(1) + + +@MODELS.register_module() +class DEKRHead(BaseHead): + """DisEntangled Keypoint Regression head introduced in `Bottom-up human + pose estimation via disentangled keypoint regression`_ by Geng et al + (2021). The head is composed of a heatmap branch and a displacement branch. + + Args: + in_channels (int | Sequence[int]): Number of channels in the input + feature map + num_joints (int): Number of joints + num_heatmap_filters (int): Number of filters for heatmap branch. + Defaults to 32 + num_offset_filters_per_joint (int): Number of filters for each joint + in displacement branch. Defaults to 15 + heatmap_loss (Config): Config of the heatmap loss. Defaults to use + :class:`KeypointMSELoss` + displacement_loss (Config): Config of the displacement regression loss. + Defaults to use :class:`SoftWeightSmoothL1Loss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + rescore_cfg (Config, optional): The config for rescore net which + estimates OKS via predicted keypoints and keypoint scores. + Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`Bottom-up human pose estimation via disentangled keypoint regression`: + https://arxiv.org/abs/2104.02300 + """ + + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + num_keypoints: int, + num_heatmap_filters: int = 32, + num_displacement_filters_per_keypoint: int = 15, + heatmap_loss: ConfigType = dict( + type='KeypointMSELoss', use_target_weight=True), + displacement_loss: ConfigType = dict( + type='SoftWeightSmoothL1Loss', + use_target_weight=True, + supervise_empty=False), + decoder: OptConfigType = None, + rescore_cfg: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.num_keypoints = num_keypoints + + # build heatmap branch + self.heatmap_conv_layers = self._make_heatmap_conv_layers( + in_channels=in_channels, + out_channels=1 + num_keypoints, + num_filters=num_heatmap_filters, + ) + + # build displacement branch + self.displacement_conv_layers = self._make_displacement_conv_layers( + in_channels=in_channels, + out_channels=2 * num_keypoints, + num_filters=num_keypoints * num_displacement_filters_per_keypoint, + groups=num_keypoints) + + # build losses + self.loss_module = ModuleDict( + dict( + heatmap=MODELS.build(heatmap_loss), + displacement=MODELS.build(displacement_loss), + )) + + # build decoder + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + # build rescore net + if rescore_cfg is not None: + self.rescore_net = RescoreNet(**rescore_cfg) + else: + self.rescore_net = None + + # Register the hook to automatically convert old version state dicts + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + @property + def default_init_cfg(self): + init_cfg = [ + dict( + type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001), + dict(type='Constant', layer='BatchNorm2d', val=1) + ] + return init_cfg + + def _make_heatmap_conv_layers(self, in_channels: int, out_channels: int, + num_filters: int): + """Create convolutional layers of heatmap branch by given + parameters.""" + layers = [ + ConvModule( + in_channels=in_channels, + out_channels=num_filters, + kernel_size=1, + norm_cfg=dict(type='BN')), + BasicBlock(num_filters, num_filters), + build_conv_layer( + dict(type='Conv2d'), + in_channels=num_filters, + out_channels=out_channels, + kernel_size=1), + ] + + return Sequential(*layers) + + def _make_displacement_conv_layers(self, in_channels: int, + out_channels: int, num_filters: int, + groups: int): + """Create convolutional layers of displacement branch by given + parameters.""" + layers = [ + ConvModule( + in_channels=in_channels, + out_channels=num_filters, + kernel_size=1, + norm_cfg=dict(type='BN')), + AdaptiveActivationBlock(num_filters, num_filters, groups=groups), + AdaptiveActivationBlock(num_filters, num_filters, groups=groups), + build_conv_layer( + dict(type='Conv2d'), + in_channels=num_filters, + out_channels=out_channels, + kernel_size=1, + groups=groups) + ] + + return Sequential(*layers) + + def forward(self, feats: Tuple[Tensor]) -> Tensor: + """Forward the network. The input is multi scale feature maps and the + output is a tuple of heatmap and displacement. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + Tuple[Tensor]: output heatmap and displacement. + """ + x = feats[-1] + + heatmaps = self.heatmap_conv_layers(x) + displacements = self.displacement_conv_layers(x) + + return heatmaps, displacements + + def loss(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + feats (Tuple[Tensor]): The multi-stage features + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + train_cfg (dict): The runtime config for training process. + Defaults to {} + + Returns: + dict: A dictionary of losses. + """ + pred_heatmaps, pred_displacements = self.forward(feats) + gt_heatmaps = torch.stack( + [d.gt_fields.heatmaps for d in batch_data_samples]) + heatmap_weights = torch.stack( + [d.gt_fields.heatmap_weights for d in batch_data_samples]) + gt_displacements = torch.stack( + [d.gt_fields.displacements for d in batch_data_samples]) + displacement_weights = torch.stack( + [d.gt_fields.displacement_weights for d in batch_data_samples]) + + if 'heatmap_mask' in batch_data_samples[0].gt_fields.keys(): + heatmap_mask = torch.stack( + [d.gt_fields.heatmap_mask for d in batch_data_samples]) + else: + heatmap_mask = None + + # calculate losses + losses = dict() + heatmap_loss = self.loss_module['heatmap'](pred_heatmaps, gt_heatmaps, + heatmap_weights, + heatmap_mask) + displacement_loss = self.loss_module['displacement']( + pred_displacements, gt_displacements, displacement_weights) + + losses.update({ + 'loss/heatmap': heatmap_loss, + 'loss/displacement': displacement_loss, + }) + + return losses + + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from features. + + Args: + feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage + features (or multiple multi-scale features in TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. Defaults + to {} + + Returns: + Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If + ``test_cfg['output_heatmap']==True``, return both pose and heatmap + prediction; otherwise only return the pose prediction. + + The pose prediction is a list of ``InstanceData``, each contains + the following fields: + + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + + The heatmap prediction is a list of ``PixelData``, each contains + the following fields: + + - heatmaps (Tensor): The predicted heatmaps in shape (1, h, w) + or (K+1, h, w) if keypoint heatmaps are predicted + - displacements (Tensor): The predicted displacement fields + in shape (K*2, h, w) + """ + + assert len(batch_data_samples) == 1, f'DEKRHead only supports ' \ + f'prediction with batch_size 1, but got {len(batch_data_samples)}' + + multiscale_test = test_cfg.get('multiscale_test', False) + flip_test = test_cfg.get('flip_test', False) + metainfo = batch_data_samples[0].metainfo + aug_scales = [1] + + if not multiscale_test: + feats = [feats] + else: + aug_scales = aug_scales + metainfo['aug_scales'] + + heatmaps, displacements = [], [] + for feat, s in zip(feats, aug_scales): + if flip_test: + assert isinstance(feat, list) and len(feat) == 2 + flip_indices = metainfo['flip_indices'] + _feat, _feat_flip = feat + _heatmaps, _displacements = self.forward(_feat) + _heatmaps_flip, _displacements_flip = self.forward(_feat_flip) + + _heatmaps_flip = flip_heatmaps( + _heatmaps_flip, + flip_mode='heatmap', + flip_indices=flip_indices + [len(flip_indices)], + shift_heatmap=test_cfg.get('shift_heatmap', False)) + _heatmaps = (_heatmaps + _heatmaps_flip) / 2.0 + + _displacements_flip = flip_heatmaps( + _displacements_flip, + flip_mode='offset', + flip_indices=flip_indices, + shift_heatmap=False) + + # this is a coordinate amendment. + x_scale_factor = s * ( + metainfo['input_size'][0] / _heatmaps.shape[-1]) + _displacements_flip[:, ::2] += (x_scale_factor - 1) / ( + x_scale_factor) + _displacements = (_displacements + _displacements_flip) / 2.0 + + else: + _heatmaps, _displacements = self.forward(feat) + + heatmaps.append(_heatmaps) + displacements.append(_displacements) + + preds = self.decode(heatmaps, displacements, test_cfg, metainfo) + + if test_cfg.get('output_heatmaps', False): + heatmaps = [hm.detach() for hm in heatmaps] + displacements = [dm.detach() for dm in displacements] + B = heatmaps[0].shape[0] + pred_fields = [] + for i in range(B): + pred_fields.append( + PixelData( + heatmaps=heatmaps[0][i], + displacements=displacements[0][i])) + return preds, pred_fields + else: + return preds + + def decode(self, + heatmaps: Tuple[Tensor], + displacements: Tuple[Tensor], + test_cfg: ConfigType = {}, + metainfo: dict = {}) -> InstanceList: + """Decode keypoints from outputs. + + Args: + heatmaps (Tuple[Tensor]): The output heatmaps inferred from one + image or multi-scale images. + displacements (Tuple[Tensor]): The output displacement fields + inferred from one image or multi-scale images. + test_cfg (dict): The runtime config for testing process. Defaults + to {} + metainfo (dict): The metainfo of test dataset. Defaults to {} + + Returns: + List[InstanceData]: A list of InstanceData, each contains the + decoded pose information of the instances of one data sample. + """ + + if self.decoder is None: + raise RuntimeError( + f'The decoder has not been set in {self.__class__.__name__}. ' + 'Please set the decoder configs in the init parameters to ' + 'enable head methods `head.predict()` and `head.decode()`') + + multiscale_test = test_cfg.get('multiscale_test', False) + skeleton = metainfo.get('skeleton_links', None) + + preds = [] + batch_size = heatmaps[0].shape[0] + + for b in range(batch_size): + if multiscale_test: + raise NotImplementedError + else: + keypoints, (root_scores, + keypoint_scores) = self.decoder.decode( + heatmaps[0][b], displacements[0][b]) + + # rescore each instance + if self.rescore_net is not None and skeleton and len( + keypoints) > 0: + instance_scores = self.rescore_net(keypoints, keypoint_scores, + skeleton) + instance_scores[torch.isnan(instance_scores)] = 0 + root_scores = root_scores * instance_scores + + # nms + keypoints, keypoint_scores = to_numpy((keypoints, keypoint_scores)) + scores = to_numpy(root_scores)[..., None] * keypoint_scores + if len(keypoints) > 0 and test_cfg.get('nms_dist_thr', 0) > 0: + kpts_db = [] + for i in range(len(keypoints)): + kpts_db.append( + dict(keypoints=keypoints[i], score=keypoint_scores[i])) + keep_instance_inds = nearby_joints_nms( + kpts_db, + test_cfg['nms_dist_thr'], + test_cfg.get('nms_joints_thr', None), + score_per_joint=True, + max_dets=test_cfg.get('max_num_people', 30)) + keypoints = keypoints[keep_instance_inds] + scores = scores[keep_instance_inds] + + # pack outputs + preds.append( + InstanceData(keypoints=keypoints, keypoint_scores=scores)) + + return preds + + def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, + **kwargs): + """A hook function to convert old-version state dict of + :class:`DEKRHead` (before MMPose v1.0.0) to a compatible format + of :class:`DEKRHead`. + + The hook will be automatically registered during initialization. + """ + version = local_meta.get('version', None) + if version and version >= self._version: + return + + # convert old-version state dict + keys = list(state_dict.keys()) + for k in keys: + if 'offset_conv_layer' in k: + v = state_dict.pop(k) + k = k.replace('offset_conv_layers', 'displacement_conv_layers') + if 'displacement_conv_layers.3.' in k: + # the source and target of displacement vectors are + # opposite between two versions. + v = -v + state_dict[k] = v + + if 'heatmap_conv_layers.2' in k: + # root heatmap is at the first/last channel of the + # heatmap tensor in MMPose v0.x/1.x, respectively. + v = state_dict.pop(k) + state_dict[k] = torch.cat((v[1:], v[:1])) + + if 'rescore_net' in k: + v = state_dict.pop(k) + k = k.replace('rescore_net', 'head.rescore_net') + state_dict[k] = v diff --git a/mmpose/models/heads/regression_heads/__init__.py b/mmpose/models/heads/regression_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a5027b1b6f33f488c268c3bde0142681f4ac4c --- /dev/null +++ b/mmpose/models/heads/regression_heads/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dsnt_head import DSNTHead +from .integral_regression_head import IntegralRegressionHead +from .regression_head import RegressionHead +from .rle_head import RLEHead + +__all__ = [ + 'RegressionHead', + 'IntegralRegressionHead', + 'DSNTHead', + 'RLEHead', +] diff --git a/mmpose/models/heads/regression_heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/regression_heads/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c04fcd7346069d3e36bccf6e1c0720658b8408df Binary files /dev/null and b/mmpose/models/heads/regression_heads/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/heads/regression_heads/__pycache__/dsnt_head.cpython-38.pyc b/mmpose/models/heads/regression_heads/__pycache__/dsnt_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e22015c60eb75b8cc8bd4e65ede26098d270fa41 Binary files /dev/null and b/mmpose/models/heads/regression_heads/__pycache__/dsnt_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/regression_heads/__pycache__/integral_regression_head.cpython-38.pyc b/mmpose/models/heads/regression_heads/__pycache__/integral_regression_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..088d83163c4b8d9d7d0b482304933bc11a45d4ed Binary files /dev/null and b/mmpose/models/heads/regression_heads/__pycache__/integral_regression_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/regression_heads/__pycache__/regression_head.cpython-38.pyc b/mmpose/models/heads/regression_heads/__pycache__/regression_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6635b75277ea58e5011239676edd866eab9adc32 Binary files /dev/null and b/mmpose/models/heads/regression_heads/__pycache__/regression_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/regression_heads/__pycache__/rle_head.cpython-38.pyc b/mmpose/models/heads/regression_heads/__pycache__/rle_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31cd606ae22b3947142260f37a648e08de0e2684 Binary files /dev/null and b/mmpose/models/heads/regression_heads/__pycache__/rle_head.cpython-38.pyc differ diff --git a/mmpose/models/heads/regression_heads/dsnt_head.py b/mmpose/models/heads/regression_heads/dsnt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd49e385db31c996de086419285e2f5fa7748b3 --- /dev/null +++ b/mmpose/models/heads/regression_heads/dsnt_head.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from mmengine.logging import MessageHub +from torch import Tensor + +from mmpose.evaluation.functional import keypoint_pck_accuracy +from mmpose.registry import MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import ConfigType, OptConfigType, OptSampleList +from .integral_regression_head import IntegralRegressionHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class DSNTHead(IntegralRegressionHead): + """Top-down integral regression head introduced in `DSNT`_ by Nibali et + al(2018). The head contains a differentiable spatial to numerical transform + (DSNT) layer that do soft-argmax operation on the predicted heatmaps to + regress the coordinates. + + This head is used for algorithms that require supervision of heatmaps + in `DSNT` approach. + + Args: + in_channels (int | sequence[int]): Number of input channels + in_featuremap_size (int | sequence[int]): Size of input feature map + num_joints (int): Number of joints + lambda_t (int): Discard heatmap-based loss when current + epoch > lambda_t. Defaults to -1. + debias (bool): Whether to remove the bias of Integral Pose Regression. + see `Removing the Bias of Integral Pose Regression`_ by Gu et al + (2021). Defaults to ``False``. + beta (float): A smoothing parameter in softmax. Defaults to ``1.0``. + deconv_out_channels (sequence[int]): The output channel number of each + deconv layer. Defaults to ``(256, 256, 256)`` + deconv_kernel_sizes (sequence[int | tuple], optional): The kernel size + of each deconv layer. Each element should be either an integer for + both height and width dimensions, or a tuple of two integers for + the height and the width dimension respectively.Defaults to + ``(4, 4, 4)`` + conv_out_channels (sequence[int], optional): The output channel number + of each intermediate conv layer. ``None`` means no intermediate + conv layer between deconv layers and the final conv layer. + Defaults to ``None`` + conv_kernel_sizes (sequence[int | tuple], optional): The kernel size + of each intermediate conv layer. Defaults to ``None`` + final_layer (dict): Arguments of the final Conv2d layer. + Defaults to ``dict(kernel_size=1)`` + loss (Config): Config for keypoint loss. Defaults to use + :class:`DSNTLoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`DSNT`: https://arxiv.org/abs/1801.07372 + """ + + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + in_featuremap_size: Tuple[int, int], + num_joints: int, + lambda_t: int = -1, + debias: bool = False, + beta: float = 1.0, + deconv_out_channels: OptIntSeq = (256, 256, 256), + deconv_kernel_sizes: OptIntSeq = (4, 4, 4), + conv_out_channels: OptIntSeq = None, + conv_kernel_sizes: OptIntSeq = None, + final_layer: dict = dict(kernel_size=1), + loss: ConfigType = dict( + type='MultipleLossWrapper', + losses=[ + dict(type='SmoothL1Loss', use_target_weight=True), + dict(type='JSDiscretLoss', use_target_weight=True) + ]), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + super().__init__( + in_channels=in_channels, + in_featuremap_size=in_featuremap_size, + num_joints=num_joints, + debias=debias, + beta=beta, + deconv_out_channels=deconv_out_channels, + deconv_kernel_sizes=deconv_kernel_sizes, + conv_out_channels=conv_out_channels, + conv_kernel_sizes=conv_kernel_sizes, + final_layer=final_layer, + loss=loss, + decoder=decoder, + init_cfg=init_cfg) + + self.lambda_t = lambda_t + + def loss(self, + inputs: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_coords, pred_heatmaps = self.forward(inputs) + keypoint_labels = torch.cat( + [d.gt_instance_labels.keypoint_labels for d in batch_data_samples]) + keypoint_weights = torch.cat([ + d.gt_instance_labels.keypoint_weights for d in batch_data_samples + ]) + gt_heatmaps = torch.stack( + [d.gt_fields.heatmaps for d in batch_data_samples]) + + input_list = [pred_coords, pred_heatmaps] + target_list = [keypoint_labels, gt_heatmaps] + # calculate losses + losses = dict() + + loss_list = self.loss_module(input_list, target_list, keypoint_weights) + + loss = loss_list[0] + loss_list[1] + + if self.lambda_t > 0: + mh = MessageHub.get_current_instance() + cur_epoch = mh.get_info('epoch') + if cur_epoch >= self.lambda_t: + loss = loss_list[0] + + losses.update(loss_kpt=loss) + + # calculate accuracy + _, avg_acc, _ = keypoint_pck_accuracy( + pred=to_numpy(pred_coords), + gt=to_numpy(keypoint_labels), + mask=to_numpy(keypoint_weights) > 0, + thr=0.05, + norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32)) + + acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device) + losses.update(acc_pose=acc_pose) + + return losses diff --git a/mmpose/models/heads/regression_heads/integral_regression_head.py b/mmpose/models/heads/regression_heads/integral_regression_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9046d94ad4318a19a3037f839ee054a445c80c68 --- /dev/null +++ b/mmpose/models/heads/regression_heads/integral_regression_head.py @@ -0,0 +1,339 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer +from mmengine.structures import PixelData +from torch import Tensor, nn + +from mmpose.evaluation.functional import keypoint_pck_accuracy +from mmpose.models.utils.tta import flip_coordinates, flip_heatmaps +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList, + Predictions) +from .. import HeatmapHead +from ..base_head import BaseHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class IntegralRegressionHead(BaseHead): + """Top-down integral regression head introduced in `IPR`_ by Xiao et + al(2018). The head contains a differentiable spatial to numerical transform + (DSNT) layer that do soft-argmax operation on the predicted heatmaps to + regress the coordinates. + + This head is used for algorithms that only supervise the coordinates. + + Args: + in_channels (int | sequence[int]): Number of input channels + in_featuremap_size (int | sequence[int]): Size of input feature map + num_joints (int): Number of joints + debias (bool): Whether to remove the bias of Integral Pose Regression. + see `Removing the Bias of Integral Pose Regression`_ by Gu et al + (2021). Defaults to ``False``. + beta (float): A smoothing parameter in softmax. Defaults to ``1.0``. + deconv_out_channels (sequence[int]): The output channel number of each + deconv layer. Defaults to ``(256, 256, 256)`` + deconv_kernel_sizes (sequence[int | tuple], optional): The kernel size + of each deconv layer. Each element should be either an integer for + both height and width dimensions, or a tuple of two integers for + the height and the width dimension respectively.Defaults to + ``(4, 4, 4)`` + conv_out_channels (sequence[int], optional): The output channel number + of each intermediate conv layer. ``None`` means no intermediate + conv layer between deconv layers and the final conv layer. + Defaults to ``None`` + conv_kernel_sizes (sequence[int | tuple], optional): The kernel size + of each intermediate conv layer. Defaults to ``None`` + final_layer (dict): Arguments of the final Conv2d layer. + Defaults to ``dict(kernel_size=1)`` + loss (Config): Config for keypoint loss. Defaults to use + :class:`SmoothL1Loss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`IPR`: https://arxiv.org/abs/1711.08229 + .. _`Debias`: + """ + + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + in_featuremap_size: Tuple[int, int], + num_joints: int, + debias: bool = False, + beta: float = 1.0, + deconv_out_channels: OptIntSeq = (256, 256, 256), + deconv_kernel_sizes: OptIntSeq = (4, 4, 4), + conv_out_channels: OptIntSeq = None, + conv_kernel_sizes: OptIntSeq = None, + final_layer: dict = dict(kernel_size=1), + loss: ConfigType = dict( + type='SmoothL1Loss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.num_joints = num_joints + self.debias = debias + self.beta = beta + self.loss_module = MODELS.build(loss) + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + num_deconv = len(deconv_out_channels) if deconv_out_channels else 0 + if num_deconv != 0: + + self.heatmap_size = tuple( + [s * (2**num_deconv) for s in in_featuremap_size]) + + # deconv layers + 1x1 conv + self.simplebaseline_head = HeatmapHead( + in_channels=in_channels, + out_channels=num_joints, + deconv_out_channels=deconv_out_channels, + deconv_kernel_sizes=deconv_kernel_sizes, + conv_out_channels=conv_out_channels, + conv_kernel_sizes=conv_kernel_sizes, + final_layer=final_layer) + + if final_layer is not None: + in_channels = num_joints + else: + in_channels = deconv_out_channels[-1] + + else: + self.simplebaseline_head = None + + if final_layer is not None: + cfg = dict( + type='Conv2d', + in_channels=in_channels, + out_channels=num_joints, + kernel_size=1) + cfg.update(final_layer) + self.final_layer = build_conv_layer(cfg) + else: + self.final_layer = None + + self.heatmap_size = in_featuremap_size + + if isinstance(in_channels, list): + raise ValueError( + f'{self.__class__.__name__} does not support selecting ' + 'multiple input features.') + + W, H = self.heatmap_size + self.linspace_x = torch.arange(0.0, 1.0 * W, 1).reshape(1, 1, 1, W) / W + self.linspace_y = torch.arange(0.0, 1.0 * H, 1).reshape(1, 1, H, 1) / H + + self.linspace_x = nn.Parameter(self.linspace_x, requires_grad=False) + self.linspace_y = nn.Parameter(self.linspace_y, requires_grad=False) + + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def _linear_expectation(self, heatmaps: Tensor, + linspace: Tensor) -> Tensor: + """Calculate linear expectation.""" + + B, N, _, _ = heatmaps.shape + heatmaps = heatmaps.mul(linspace).reshape(B, N, -1) + expectation = torch.sum(heatmaps, dim=2, keepdim=True) + + return expectation + + def _flat_softmax(self, featmaps: Tensor) -> Tensor: + """Use Softmax to normalize the featmaps in depthwise.""" + + _, N, H, W = featmaps.shape + + featmaps = featmaps.reshape(-1, N, H * W) + heatmaps = F.softmax(featmaps, dim=2) + + return heatmaps.reshape(-1, N, H, W) + + def forward(self, feats: Tuple[Tensor]) -> Union[Tensor, Tuple[Tensor]]: + """Forward the network. The input is multi scale feature maps and the + output is the coordinates. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + Tensor: output coordinates(and sigmas[optional]). + """ + if self.simplebaseline_head is None: + feats = feats[-1] + if self.final_layer is not None: + feats = self.final_layer(feats) + else: + feats = self.simplebaseline_head(feats) + + heatmaps = self._flat_softmax(feats * self.beta) + + pred_x = self._linear_expectation(heatmaps, self.linspace_x) + pred_y = self._linear_expectation(heatmaps, self.linspace_y) + + if self.debias: + B, N, H, W = feats.shape + C = feats.reshape(B, N, H * W).exp().sum(dim=2).reshape(B, N, 1) + pred_x = C / (C - 1) * (pred_x - 1 / (2 * C)) + pred_y = C / (C - 1) * (pred_y - 1 / (2 * C)) + + coords = torch.cat([pred_x, pred_y], dim=-1) + return coords, heatmaps + + def predict(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from features. + + Args: + feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage + features (or multiple multi-stage features in TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. Defaults + to {} + + Returns: + Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If + ``test_cfg['output_heatmap']==True``, return both pose and heatmap + prediction; otherwise only return the pose prediction. + + The pose prediction is a list of ``InstanceData``, each contains + the following fields: + + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + + The heatmap prediction is a list of ``PixelData``, each contains + the following fields: + + - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) + """ + + if test_cfg.get('flip_test', False): + # TTA: flip test -> feats = [orig, flipped] + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + input_size = batch_data_samples[0].metainfo['input_size'] + _feats, _feats_flip = feats + + _batch_coords, _batch_heatmaps = self.forward(_feats) + + _batch_coords_flip, _batch_heatmaps_flip = self.forward( + _feats_flip) + _batch_coords_flip = flip_coordinates( + _batch_coords_flip, + flip_indices=flip_indices, + shift_coords=test_cfg.get('shift_coords', True), + input_size=input_size) + _batch_heatmaps_flip = flip_heatmaps( + _batch_heatmaps_flip, + flip_mode='heatmap', + flip_indices=flip_indices, + shift_heatmap=test_cfg.get('shift_heatmap', False)) + + batch_coords = (_batch_coords + _batch_coords_flip) * 0.5 + batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5 + else: + batch_coords, batch_heatmaps = self.forward(feats) # (B, K, D) + + batch_coords.unsqueeze_(dim=1) # (B, N, K, D) + preds = self.decode(batch_coords) + + if test_cfg.get('output_heatmaps', False): + pred_fields = [ + PixelData(heatmaps=hm) for hm in batch_heatmaps.detach() + ] + return preds, pred_fields + else: + return preds + + def loss(self, + inputs: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_coords, _ = self.forward(inputs) + keypoint_labels = torch.cat( + [d.gt_instance_labels.keypoint_labels for d in batch_data_samples]) + keypoint_weights = torch.cat([ + d.gt_instance_labels.keypoint_weights for d in batch_data_samples + ]) + + # calculate losses + losses = dict() + + # TODO: multi-loss calculation + loss = self.loss_module(pred_coords, keypoint_labels, keypoint_weights) + + losses.update(loss_kpt=loss) + + # calculate accuracy + _, avg_acc, _ = keypoint_pck_accuracy( + pred=to_numpy(pred_coords), + gt=to_numpy(keypoint_labels), + mask=to_numpy(keypoint_weights) > 0, + thr=0.05, + norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32)) + + acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device) + losses.update(acc_pose=acc_pose) + + return losses + + @property + def default_init_cfg(self): + init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] + return init_cfg + + def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, + **kwargs): + """A hook function to load weights of deconv layers from + :class:`HeatmapHead` into `simplebaseline_head`. + + The hook will be automatically registered during initialization. + """ + + # convert old-version state dict + keys = list(state_dict.keys()) + for _k in keys: + if not _k.startswith(prefix): + continue + v = state_dict.pop(_k) + k = _k.lstrip(prefix) + + k_new = _k + k_parts = k.split('.') + if self.simplebaseline_head is not None: + if k_parts[0] == 'conv_layers': + k_new = ( + prefix + 'simplebaseline_head.deconv_layers.' + + '.'.join(k_parts[1:])) + elif k_parts[0] == 'final_layer': + k_new = prefix + 'simplebaseline_head.' + k + + state_dict[k_new] = v diff --git a/mmpose/models/heads/regression_heads/regression_head.py b/mmpose/models/heads/regression_heads/regression_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff73aa6ef1bed93e8985d9be20f3c94355d8c21 --- /dev/null +++ b/mmpose/models/heads/regression_heads/regression_head.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import Tensor, nn + +from mmpose.evaluation.functional import keypoint_pck_accuracy +from mmpose.models.utils.tta import flip_coordinates +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList, + Predictions) +from ..base_head import BaseHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class RegressionHead(BaseHead): + """Top-down regression head introduced in `Deeppose`_ by Toshev et al + (2014). The head is composed of fully-connected layers to predict the + coordinates directly. + + Args: + in_channels (int | sequence[int]): Number of input channels + num_joints (int): Number of joints + loss (Config): Config for keypoint loss. Defaults to use + :class:`SmoothL1Loss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`Deeppose`: https://arxiv.org/abs/1312.4659 + """ + + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + num_joints: int, + loss: ConfigType = dict( + type='SmoothL1Loss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.num_joints = num_joints + self.loss_module = MODELS.build(loss) + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + # Define fully-connected layers + self.fc = nn.Linear(in_channels, self.num_joints * 2) + + def forward(self, feats: Tuple[Tensor]) -> Tensor: + """Forward the network. The input is multi scale feature maps and the + output is the coordinates. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + Tensor: output coordinates(and sigmas[optional]). + """ + x = feats[-1] + + x = torch.flatten(x, 1) + x = self.fc(x) + + return x.reshape(-1, self.num_joints, 2) + + def predict(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from outputs.""" + + if test_cfg.get('flip_test', False): + # TTA: flip test -> feats = [orig, flipped] + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + input_size = batch_data_samples[0].metainfo['input_size'] + _feats, _feats_flip = feats + + _batch_coords = self.forward(_feats) + _batch_coords_flip = flip_coordinates( + self.forward(_feats_flip), + flip_indices=flip_indices, + shift_coords=test_cfg.get('shift_coords', True), + input_size=input_size) + batch_coords = (_batch_coords + _batch_coords_flip) * 0.5 + else: + batch_coords = self.forward(feats) # (B, K, D) + + batch_coords.unsqueeze_(dim=1) # (B, N, K, D) + preds = self.decode(batch_coords) + + return preds + + def loss(self, + inputs: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_outputs = self.forward(inputs) + + keypoint_labels = torch.cat( + [d.gt_instance_labels.keypoint_labels for d in batch_data_samples]) + keypoint_weights = torch.cat([ + d.gt_instance_labels.keypoint_weights for d in batch_data_samples + ]) + + # calculate losses + losses = dict() + loss = self.loss_module(pred_outputs, keypoint_labels, + keypoint_weights.unsqueeze(-1)) + + losses.update(loss_kpt=loss) + + # calculate accuracy + _, avg_acc, _ = keypoint_pck_accuracy( + pred=to_numpy(pred_outputs), + gt=to_numpy(keypoint_labels), + mask=to_numpy(keypoint_weights) > 0, + thr=0.05, + norm_factor=np.ones((pred_outputs.size(0), 2), dtype=np.float32)) + + acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device) + losses.update(acc_pose=acc_pose) + + return losses + + @property + def default_init_cfg(self): + init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] + return init_cfg diff --git a/mmpose/models/heads/regression_heads/rle_head.py b/mmpose/models/heads/regression_heads/rle_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ef62d7d9acbe1235c47d8acbef47e38ab1be6348 --- /dev/null +++ b/mmpose/models/heads/regression_heads/rle_head.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import Tensor, nn + +from mmpose.evaluation.functional import keypoint_pck_accuracy +from mmpose.models.utils.tta import flip_coordinates +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList, + Predictions) +from ..base_head import BaseHead + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class RLEHead(BaseHead): + """Top-down regression head introduced in `RLE`_ by Li et al(2021). The + head is composed of fully-connected layers to predict the coordinates and + sigma(the variance of the coordinates) together. + + Args: + in_channels (int | sequence[int]): Number of input channels + num_joints (int): Number of joints + loss (Config): Config for keypoint loss. Defaults to use + :class:`RLELoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`RLE`: https://arxiv.org/abs/2107.11291 + """ + + _version = 2 + + def __init__(self, + in_channels: Union[int, Sequence[int]], + num_joints: int, + loss: ConfigType = dict( + type='RLELoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.num_joints = num_joints + self.loss_module = MODELS.build(loss) + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + # Define fully-connected layers + self.fc = nn.Linear(in_channels, self.num_joints * 4) + + # Register the hook to automatically convert old version state dicts + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def forward(self, feats: Tuple[Tensor]) -> Tensor: + """Forward the network. The input is multi scale feature maps and the + output is the coordinates. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + Tensor: output coordinates(and sigmas[optional]). + """ + x = feats[-1] + + x = torch.flatten(x, 1) + x = self.fc(x) + + return x.reshape(-1, self.num_joints, 4) + + def predict(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from outputs.""" + + if test_cfg.get('flip_test', False): + # TTA: flip test -> feats = [orig, flipped] + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + input_size = batch_data_samples[0].metainfo['input_size'] + + _feats, _feats_flip = feats + + _batch_coords = self.forward(_feats) + _batch_coords[..., 2:] = _batch_coords[..., 2:].sigmoid() + + _batch_coords_flip = flip_coordinates( + self.forward(_feats_flip), + flip_indices=flip_indices, + shift_coords=test_cfg.get('shift_coords', True), + input_size=input_size) + _batch_coords_flip[..., 2:] = _batch_coords_flip[..., 2:].sigmoid() + + batch_coords = (_batch_coords + _batch_coords_flip) * 0.5 + else: + batch_coords = self.forward(feats) # (B, K, D) + batch_coords[..., 2:] = batch_coords[..., 2:].sigmoid() + + batch_coords.unsqueeze_(dim=1) # (B, N, K, D) + preds = self.decode(batch_coords) + + return preds + + def loss(self, + inputs: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_outputs = self.forward(inputs) + + keypoint_labels = torch.cat( + [d.gt_instance_labels.keypoint_labels for d in batch_data_samples]) + keypoint_weights = torch.cat([ + d.gt_instance_labels.keypoint_weights for d in batch_data_samples + ]) + + pred_coords = pred_outputs[:, :, :2] + pred_sigma = pred_outputs[:, :, 2:4] + + # calculate losses + losses = dict() + loss = self.loss_module(pred_coords, pred_sigma, keypoint_labels, + keypoint_weights.unsqueeze(-1)) + + losses.update(loss_kpt=loss) + + # calculate accuracy + _, avg_acc, _ = keypoint_pck_accuracy( + pred=to_numpy(pred_coords), + gt=to_numpy(keypoint_labels), + mask=to_numpy(keypoint_weights) > 0, + thr=0.05, + norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32)) + + acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device) + losses.update(acc_pose=acc_pose) + + return losses + + def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, + **kwargs): + """A hook function to convert old-version state dict of + :class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a + compatible format of :class:`HeatmapHead`. + + The hook will be automatically registered during initialization. + """ + + version = local_meta.get('version', None) + if version and version >= self._version: + return + + # convert old-version state dict + keys = list(state_dict.keys()) + for _k in keys: + v = state_dict.pop(_k) + k = _k.lstrip(prefix) + # In old version, "loss" includes the instances of loss, + # now it should be renamed "loss_module" + k_parts = k.split('.') + if k_parts[0] == 'loss': + # loss.xxx -> loss_module.xxx + k_new = prefix + 'loss_module.' + '.'.join(k_parts[1:]) + else: + k_new = _k + + state_dict[k_new] = v + + @property + def default_init_cfg(self): + init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] + return init_cfg diff --git a/mmpose/models/losses/__init__.py b/mmpose/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f21071e15699c42ca21f49e62fe0fb9f2869a68e --- /dev/null +++ b/mmpose/models/losses/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ae_loss import AssociativeEmbeddingLoss +from .classification_loss import BCELoss, JSDiscretLoss, KLDiscretLoss +from .heatmap_loss import (AdaptiveWingLoss, KeypointMSELoss, + KeypointOHKMMSELoss) +from .loss_wrappers import CombinedLoss, MultipleLossWrapper +from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss, RLELoss, + SemiSupervisionLoss, SmoothL1Loss, + SoftWeightSmoothL1Loss, SoftWingLoss, WingLoss) + +__all__ = [ + 'KeypointMSELoss', 'KeypointOHKMMSELoss', 'SmoothL1Loss', 'WingLoss', + 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss', + 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss', + 'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss', 'CombinedLoss', + 'AssociativeEmbeddingLoss', 'SoftWeightSmoothL1Loss' +] diff --git a/mmpose/models/losses/__pycache__/__init__.cpython-38.pyc b/mmpose/models/losses/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c60a281882de129a8e5b5ab33bc9d4d34e7ca8 Binary files /dev/null and b/mmpose/models/losses/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/losses/__pycache__/ae_loss.cpython-38.pyc b/mmpose/models/losses/__pycache__/ae_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53d5c9ad4ef2e04d19925d834f442240e3cadd82 Binary files /dev/null and b/mmpose/models/losses/__pycache__/ae_loss.cpython-38.pyc differ diff --git a/mmpose/models/losses/__pycache__/classification_loss.cpython-38.pyc b/mmpose/models/losses/__pycache__/classification_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea4cb8ac587ce0210fed22c097e51f079fd7e250 Binary files /dev/null and b/mmpose/models/losses/__pycache__/classification_loss.cpython-38.pyc differ diff --git a/mmpose/models/losses/__pycache__/heatmap_loss.cpython-38.pyc b/mmpose/models/losses/__pycache__/heatmap_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..023ffe15dd8dc4bfe72cfcb1d236b4c953055a1f Binary files /dev/null and b/mmpose/models/losses/__pycache__/heatmap_loss.cpython-38.pyc differ diff --git a/mmpose/models/losses/__pycache__/loss_wrappers.cpython-38.pyc b/mmpose/models/losses/__pycache__/loss_wrappers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..198ff7f4b4baaf276f3cf2215f7e432eefb666ac Binary files /dev/null and b/mmpose/models/losses/__pycache__/loss_wrappers.cpython-38.pyc differ diff --git a/mmpose/models/losses/__pycache__/regression_loss.cpython-38.pyc b/mmpose/models/losses/__pycache__/regression_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76a5c14f73e6f0e97fe8985e7e218624c8d0b524 Binary files /dev/null and b/mmpose/models/losses/__pycache__/regression_loss.cpython-38.pyc differ diff --git a/mmpose/models/losses/ae_loss.py b/mmpose/models/losses/ae_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1e08181beaf835238596d95fe509b122c64b3d --- /dev/null +++ b/mmpose/models/losses/ae_loss.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import List, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmpose.registry import MODELS + + +@MODELS.register_module() +class AssociativeEmbeddingLoss(nn.Module): + """Associative Embedding loss. + + Details can be found in + `Associative Embedding `_ + + Note: + + - batch size: B + - instance number: N + - keypoint number: K + - keypoint dimension: D + - embedding tag dimension: L + - heatmap size: [W, H] + + Args: + loss_weight (float): Weight of the loss. Defaults to 1.0 + push_loss_factor (float): A factor that controls the weight between + the push loss and the pull loss. Defaults to 0.5 + """ + + def __init__(self, + loss_weight: float = 1.0, + push_loss_factor: float = 0.5) -> None: + super().__init__() + self.loss_weight = loss_weight + self.push_loss_factor = push_loss_factor + + def _ae_loss_per_image(self, tags: Tensor, keypoint_indices: Tensor): + """Compute associative embedding loss for one image. + + Args: + tags (Tensor): Tagging heatmaps in shape (K*L, H, W) + keypoint_indices (Tensor): Ground-truth keypint position indices + in shape (N, K, 2) + """ + K = keypoint_indices.shape[1] + C, H, W = tags.shape + L = C // K + + tags = tags.view(L, K, H * W) + instance_tags = [] + instance_kpt_tags = [] + + for keypoint_indices_n in keypoint_indices: + _kpt_tags = [] + for k in range(K): + if keypoint_indices_n[k, 1]: + _kpt_tags.append(tags[:, k, keypoint_indices_n[k, 0]]) + + if _kpt_tags: + kpt_tags = torch.stack(_kpt_tags) + instance_kpt_tags.append(kpt_tags) + instance_tags.append(kpt_tags.mean(dim=0)) + + N = len(instance_kpt_tags) # number of instances with valid keypoints + + if N == 0: + pull_loss = tags.new_zeros(size=(), requires_grad=True) + push_loss = tags.new_zeros(size=(), requires_grad=True) + else: + pull_loss = sum( + F.mse_loss(_kpt_tags, _tag.expand_as(_kpt_tags)) + for (_kpt_tags, _tag) in zip(instance_kpt_tags, instance_tags)) + + if N == 1: + push_loss = tags.new_zeros(size=(), requires_grad=True) + else: + tag_mat = torch.stack(instance_tags) # (N, L) + diff = tag_mat[None] - tag_mat[:, None] # (N, N, L) + push_loss = torch.sum(torch.exp(-diff.pow(2))) + + # normalization + eps = 1e-6 + pull_loss = pull_loss / (N + eps) + push_loss = push_loss / ((N - 1) * N + eps) + + return pull_loss, push_loss + + def forward(self, tags: Tensor, keypoint_indices: Union[List[Tensor], + Tensor]): + """Compute associative embedding loss on a batch of data. + + Args: + tags (Tensor): Tagging heatmaps in shape (B, L*K, H, W) + keypoint_indices (Tensor|List[Tensor]): Ground-truth keypint + position indices represented by a Tensor in shape + (B, N, K, 2), or a list of B Tensors in shape (N_i, K, 2) + Each keypoint's index is represented as [i, v], where i is the + position index in the heatmap (:math:`i=y*w+x`) and v is the + visibility + + Returns: + tuple: + - pull_loss (Tensor) + - push_loss (Tensor) + """ + + assert tags.shape[0] == len(keypoint_indices) + + pull_loss = 0. + push_loss = 0. + + for i in range(tags.shape[0]): + _pull, _push = self._ae_loss_per_image(tags[i], + keypoint_indices[i]) + pull_loss += _pull * self.loss_weight + push_loss += _push * self.loss_weight * self.push_loss_factor + + return pull_loss, push_loss diff --git a/mmpose/models/losses/classification_loss.py b/mmpose/models/losses/classification_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3bdf502b7f973d8fe9c2217faa7e6ad3a5c849 --- /dev/null +++ b/mmpose/models/losses/classification_loss.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpose.registry import MODELS + + +@MODELS.register_module() +class BCELoss(nn.Module): + """Binary Cross Entropy loss. + + Args: + use_target_weight (bool): Option to use weighted loss. + Different joint types may have different target weights. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, use_target_weight=False, loss_weight=1.): + super().__init__() + self.criterion = F.binary_cross_entropy + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_labels: K + + Args: + output (torch.Tensor[N, K]): Output classification. + target (torch.Tensor[N, K]): Target classification. + target_weight (torch.Tensor[N, K] or torch.Tensor[N]): + Weights across different labels. + """ + + if self.use_target_weight: + assert target_weight is not None + loss = self.criterion(output, target, reduction='none') + if target_weight.dim() == 1: + target_weight = target_weight[:, None] + loss = (loss * target_weight).mean() + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + +@MODELS.register_module() +class JSDiscretLoss(nn.Module): + """Discrete JS Divergence loss for DSNT with Gaussian Heatmap. + + Modified from `the official implementation + `_. + + Args: + use_target_weight (bool): Option to use weighted loss. + Different joint types may have different target weights. + size_average (bool): Option to average the loss by the batch_size. + """ + + def __init__( + self, + use_target_weight=True, + size_average: bool = True, + ): + super(JSDiscretLoss, self).__init__() + self.use_target_weight = use_target_weight + self.size_average = size_average + self.kl_loss = nn.KLDivLoss(reduction='none') + + def kl(self, p, q): + """Kullback-Leibler Divergence.""" + + eps = 1e-24 + kl_values = self.kl_loss((q + eps).log(), p) + return kl_values + + def js(self, pred_hm, gt_hm): + """Jensen-Shannon Divergence.""" + + m = 0.5 * (pred_hm + gt_hm) + js_values = 0.5 * (self.kl(pred_hm, m) + self.kl(gt_hm, m)) + return js_values + + def forward(self, pred_hm, gt_hm, target_weight=None): + """Forward function. + + Args: + pred_hm (torch.Tensor[N, K, H, W]): Predicted heatmaps. + gt_hm (torch.Tensor[N, K, H, W]): Target heatmaps. + target_weight (torch.Tensor[N, K] or torch.Tensor[N]): + Weights across different labels. + + Returns: + torch.Tensor: Loss value. + """ + + if self.use_target_weight: + assert target_weight is not None + assert pred_hm.ndim >= target_weight.ndim + + for i in range(pred_hm.ndim - target_weight.ndim): + target_weight = target_weight.unsqueeze(-1) + + loss = self.js(pred_hm * target_weight, gt_hm * target_weight) + else: + loss = self.js(pred_hm, gt_hm) + + if self.size_average: + loss /= len(gt_hm) + + return loss.sum() + + +@MODELS.register_module() +class KLDiscretLoss(nn.Module): + """Discrete KL Divergence loss for SimCC with Gaussian Label Smoothing. + Modified from `the official implementation. + + `_. + Args: + beta (float): Temperature factor of Softmax. + label_softmax (bool): Whether to use Softmax on labels. + use_target_weight (bool): Option to use weighted loss. + Different joint types may have different target weights. + """ + + def __init__(self, beta=1.0, label_softmax=False, use_target_weight=True): + super(KLDiscretLoss, self).__init__() + self.beta = beta + self.label_softmax = label_softmax + self.use_target_weight = use_target_weight + + self.log_softmax = nn.LogSoftmax(dim=1) + self.kl_loss = nn.KLDivLoss(reduction='none') + + def criterion(self, dec_outs, labels): + """Criterion function.""" + log_pt = self.log_softmax(dec_outs * self.beta) + if self.label_softmax: + labels = F.softmax(labels * self.beta, dim=1) + loss = torch.mean(self.kl_loss(log_pt, labels), dim=1) + return loss + + def forward(self, pred_simcc, gt_simcc, target_weight): + """Forward function. + + Args: + pred_simcc (Tuple[Tensor, Tensor]): Predicted SimCC vectors of + x-axis and y-axis. + gt_simcc (Tuple[Tensor, Tensor]): Target representations. + target_weight (torch.Tensor[N, K] or torch.Tensor[N]): + Weights across different labels. + """ + num_joints = pred_simcc[0].size(1) + loss = 0 + + if self.use_target_weight: + weight = target_weight.reshape(-1) + else: + weight = 1. + + for pred, target in zip(pred_simcc, gt_simcc): + pred = pred.reshape(-1, pred.size(-1)) + target = target.reshape(-1, target.size(-1)) + + loss += self.criterion(pred, target).mul(weight).sum() + + return loss / num_joints + + +@MODELS.register_module() +class InfoNCELoss(nn.Module): + """InfoNCE loss for training a discriminative representation space with a + contrastive manner. + + `Representation Learning with Contrastive Predictive Coding + arXiv: `_. + + Args: + temperature (float, optional): The temperature to use in the softmax + function. Higher temperatures lead to softer probability + distributions. Defaults to 1.0. + loss_weight (float, optional): The weight to apply to the loss. + Defaults to 1.0. + """ + + def __init__(self, temperature: float = 1.0, loss_weight=1.0) -> None: + super(InfoNCELoss, self).__init__() + assert temperature > 0, f'the argument `temperature` must be ' \ + f'positive, but got {temperature}' + self.temp = temperature + self.loss_weight = loss_weight + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Computes the InfoNCE loss. + + Args: + features (Tensor): A tensor containing the feature + representations of different samples. + + Returns: + Tensor: A tensor of shape (1,) containing the InfoNCE loss. + """ + n = features.size(0) + features_norm = F.normalize(features, dim=1) + logits = features_norm.mm(features_norm.t()) / self.temp + targets = torch.arange(n, dtype=torch.long, device=features.device) + loss = F.cross_entropy(logits, targets, reduction='sum') + return loss * self.loss_weight diff --git a/mmpose/models/losses/heatmap_loss.py b/mmpose/models/losses/heatmap_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a1051494683c817d7d14a73f2ceebed67834d64b --- /dev/null +++ b/mmpose/models/losses/heatmap_loss.py @@ -0,0 +1,455 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmpose.registry import MODELS + + +@MODELS.register_module() +class KeypointMSELoss(nn.Module): + """MSE loss for heatmaps. + + Args: + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + Defaults to ``False`` + skip_empty_channel (bool): If ``True``, heatmap channels with no + non-zero value (which means no visible ground-truth keypoint + in the image) will not be used to calculate the loss. Defaults to + ``False`` + loss_weight (float): Weight of the loss. Defaults to 1.0 + """ + + def __init__(self, + use_target_weight: bool = False, + skip_empty_channel: bool = False, + loss_weight: float = 1.): + super().__init__() + self.use_target_weight = use_target_weight + self.skip_empty_channel = skip_empty_channel + self.loss_weight = loss_weight + + def forward(self, + output: Tensor, + target: Tensor, + target_weights: Optional[Tensor] = None, + mask: Optional[Tensor] = None) -> Tensor: + """Forward function of loss. + + Note: + - batch_size: B + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (Tensor): The output heatmaps with shape [B, K, H, W] + target (Tensor): The target heatmaps with shape [B, K, H, W] + target_weights (Tensor, optional): The target weights of differet + keypoints, with shape [B, K] (keypoint-wise) or + [B, K, H, W] (pixel-wise). + mask (Tensor, optional): The masks of valid heatmap pixels in + shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will + be applied. Defaults to ``None`` + + Returns: + Tensor: The calculated loss. + """ + + _mask = self._get_mask(target, target_weights, mask) + if _mask is None: + loss = F.mse_loss(output, target) + else: + _loss = F.mse_loss(output, target, reduction='none') + loss = (_loss * _mask).mean() + + return loss * self.loss_weight + + def _get_mask(self, target: Tensor, target_weights: Optional[Tensor], + mask: Optional[Tensor]) -> Optional[Tensor]: + """Generate the heatmap mask w.r.t. the given mask, target weight and + `skip_empty_channel` setting. + + Returns: + Tensor: The mask in shape (B, K, *) or ``None`` if no mask is + needed. + """ + # Given spatial mask + if mask is not None: + # check mask has matching type with target + assert (mask.ndim == target.ndim and all( + d_m == d_t or d_m == 1 + for d_m, d_t in zip(mask.shape, target.shape))), ( + f'mask and target have mismatched shapes {mask.shape} v.s.' + f'{target.shape}') + + # Mask by target weights (keypoint-wise mask) + if target_weights is not None: + # check target weight has matching shape with target + assert (target_weights.ndim in (2, 4) and target_weights.shape + == target.shape[:target_weights.ndim]), ( + 'target_weights and target have mismatched shapes ' + f'{target_weights.shape} v.s. {target.shape}') + + ndim_pad = target.ndim - target_weights.ndim + _mask = target_weights.view(target_weights.shape + + (1, ) * ndim_pad) + + if mask is None: + mask = _mask + else: + mask = mask * _mask + + # Mask by ``skip_empty_channel`` + if self.skip_empty_channel: + _mask = (target != 0).flatten(2).any() + ndim_pad = target.ndim - _mask.ndim + _mask = _mask.view(_mask.shape + (1, ) * ndim_pad) + + if mask is None: + mask = _mask + else: + mask = mask * _mask + + return mask + + +@MODELS.register_module() +class CombinedTargetMSELoss(nn.Module): + """MSE loss for combined target. + + CombinedTarget: The combination of classification target + (response map) and regression target (offset map). + Paper ref: Huang et al. The Devil is in the Details: Delving into + Unbiased Data Processing for Human Pose Estimation (CVPR 2020). + + Args: + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + Defaults to ``False`` + loss_weight (float): Weight of the loss. Defaults to 1.0 + """ + + def __init__(self, + use_target_weight: bool = False, + loss_weight: float = 1.): + super().__init__() + self.criterion = nn.MSELoss(reduction='mean') + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def forward(self, output: Tensor, target: Tensor, + target_weights: Tensor) -> Tensor: + """Forward function of loss. + + Note: + - batch_size: B + - num_channels: C + - heatmaps height: H + - heatmaps weight: W + - num_keypoints: K + Here, C = 3 * K + + Args: + output (Tensor): The output feature maps with shape [B, C, H, W]. + target (Tensor): The target feature maps with shape [B, C, H, W]. + target_weights (Tensor): The target weights of differet keypoints, + with shape [B, K]. + + Returns: + Tensor: The calculated loss. + """ + batch_size = output.size(0) + num_channels = output.size(1) + heatmaps_pred = output.reshape( + (batch_size, num_channels, -1)).split(1, 1) + heatmaps_gt = target.reshape( + (batch_size, num_channels, -1)).split(1, 1) + loss = 0. + num_joints = num_channels // 3 + for idx in range(num_joints): + heatmap_pred = heatmaps_pred[idx * 3].squeeze() + heatmap_gt = heatmaps_gt[idx * 3].squeeze() + offset_x_pred = heatmaps_pred[idx * 3 + 1].squeeze() + offset_x_gt = heatmaps_gt[idx * 3 + 1].squeeze() + offset_y_pred = heatmaps_pred[idx * 3 + 2].squeeze() + offset_y_gt = heatmaps_gt[idx * 3 + 2].squeeze() + if self.use_target_weight: + target_weight = target_weights[:, idx, None] + heatmap_pred = heatmap_pred * target_weight + heatmap_gt = heatmap_gt * target_weight + # classification loss + loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt) + # regression loss + loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred, + heatmap_gt * offset_x_gt) + loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred, + heatmap_gt * offset_y_gt) + return loss / num_joints * self.loss_weight + + +@MODELS.register_module() +class KeypointOHKMMSELoss(nn.Module): + """MSE loss with online hard keypoint mining. + + Args: + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + Defaults to ``False`` + topk (int): Only top k joint losses are kept. Defaults to 8 + loss_weight (float): Weight of the loss. Defaults to 1.0 + """ + + def __init__(self, + use_target_weight: bool = False, + topk: int = 8, + loss_weight: float = 1.): + super().__init__() + assert topk > 0 + self.criterion = nn.MSELoss(reduction='none') + self.use_target_weight = use_target_weight + self.topk = topk + self.loss_weight = loss_weight + + def _ohkm(self, losses: Tensor) -> Tensor: + """Online hard keypoint mining. + + Note: + - batch_size: B + - num_keypoints: K + + Args: + loss (Tensor): The losses with shape [B, K] + + Returns: + Tensor: The calculated loss. + """ + ohkm_loss = 0. + B = losses.shape[0] + for i in range(B): + sub_loss = losses[i] + _, topk_idx = torch.topk( + sub_loss, k=self.topk, dim=0, sorted=False) + tmp_loss = torch.gather(sub_loss, 0, topk_idx) + ohkm_loss += torch.sum(tmp_loss) / self.topk + ohkm_loss /= B + return ohkm_loss + + def forward(self, output: Tensor, target: Tensor, + target_weights: Tensor) -> Tensor: + """Forward function of loss. + + Note: + - batch_size: B + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (Tensor): The output heatmaps with shape [B, K, H, W]. + target (Tensor): The target heatmaps with shape [B, K, H, W]. + target_weights (Tensor): The target weights of differet keypoints, + with shape [B, K]. + + Returns: + Tensor: The calculated loss. + """ + num_keypoints = output.size(1) + if num_keypoints < self.topk: + raise ValueError(f'topk ({self.topk}) should not be ' + f'larger than num_keypoints ({num_keypoints}).') + + losses = [] + for idx in range(num_keypoints): + if self.use_target_weight: + target_weight = target_weights[:, idx, None, None] + losses.append( + self.criterion(output[:, idx] * target_weight, + target[:, idx] * target_weight)) + else: + losses.append(self.criterion(output[:, idx], target[:, idx])) + + losses = [loss.mean(dim=(1, 2)).unsqueeze(dim=1) for loss in losses] + losses = torch.cat(losses, dim=1) + + return self._ohkm(losses) * self.loss_weight + + +@MODELS.register_module() +class AdaptiveWingLoss(nn.Module): + """Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face + Alignment via Heatmap Regression' Wang et al. ICCV'2019. + + Args: + alpha (float), omega (float), epsilon (float), theta (float) + are hyper-parameters. + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, + alpha=2.1, + omega=14, + epsilon=1, + theta=0.5, + use_target_weight=False, + loss_weight=1.): + super().__init__() + self.alpha = float(alpha) + self.omega = float(omega) + self.epsilon = float(epsilon) + self.theta = float(theta) + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def criterion(self, pred, target): + """Criterion of wingloss. + + Note: + batch_size: N + num_keypoints: K + + Args: + pred (torch.Tensor[NxKxHxW]): Predicted heatmaps. + target (torch.Tensor[NxKxHxW]): Target heatmaps. + """ + H, W = pred.shape[2:4] + delta = (target - pred).abs() + + A = self.omega * ( + 1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target)) + ) * (self.alpha - target) * (torch.pow( + self.theta / self.epsilon, + self.alpha - target - 1)) * (1 / self.epsilon) + C = self.theta * A - self.omega * torch.log( + 1 + torch.pow(self.theta / self.epsilon, self.alpha - target)) + + losses = torch.where( + delta < self.theta, + self.omega * + torch.log(1 + + torch.pow(delta / self.epsilon, self.alpha - target)), + A * delta - C) + + return torch.mean(losses) + + def forward(self, + output: Tensor, + target: Tensor, + target_weights: Optional[Tensor] = None): + """Forward function. + + Note: + batch_size: N + num_keypoints: K + + Args: + output (torch.Tensor[N, K, H, W]): Output heatmaps. + target (torch.Tensor[N, K, H, W]): Target heatmaps. + target_weight (torch.Tensor[N, K]): + Weights across different joint types. + """ + if self.use_target_weight: + assert (target_weights.ndim in (2, 4) and target_weights.shape + == target.shape[:target_weights.ndim]), ( + 'target_weights and target have mismatched shapes ' + f'{target_weights.shape} v.s. {target.shape}') + + ndim_pad = target.ndim - target_weights.ndim + target_weights = target_weights.view(target_weights.shape + + (1, ) * ndim_pad) + loss = self.criterion(output * target_weights, + target * target_weights) + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + +@MODELS.register_module() +class FocalHeatmapLoss(KeypointMSELoss): + """A class for calculating the modified focal loss for heatmap prediction. + + This loss function is exactly the same as the one used in CornerNet. It + runs faster and costs a little bit more memory. + + `CornerNet: Detecting Objects as Paired Keypoints + arXiv: `_. + + Arguments: + alpha (int): The alpha parameter in the focal loss equation. + beta (int): The beta parameter in the focal loss equation. + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + Defaults to ``False`` + skip_empty_channel (bool): If ``True``, heatmap channels with no + non-zero value (which means no visible ground-truth keypoint + in the image) will not be used to calculate the loss. Defaults to + ``False`` + loss_weight (float): Weight of the loss. Defaults to 1.0 + """ + + def __init__(self, + alpha: int = 2, + beta: int = 4, + use_target_weight: bool = False, + skip_empty_channel: bool = False, + loss_weight: float = 1.0): + super(FocalHeatmapLoss, self).__init__(use_target_weight, + skip_empty_channel, loss_weight) + self.alpha = alpha + self.beta = beta + + def forward(self, + output: Tensor, + target: Tensor, + target_weights: Optional[Tensor] = None, + mask: Optional[Tensor] = None) -> Tensor: + """Calculate the modified focal loss for heatmap prediction. + + Note: + - batch_size: B + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (Tensor): The output heatmaps with shape [B, K, H, W] + target (Tensor): The target heatmaps with shape [B, K, H, W] + target_weights (Tensor, optional): The target weights of differet + keypoints, with shape [B, K] (keypoint-wise) or + [B, K, H, W] (pixel-wise). + mask (Tensor, optional): The masks of valid heatmap pixels in + shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will + be applied. Defaults to ``None`` + + Returns: + Tensor: The calculated loss. + """ + _mask = self._get_mask(target, target_weights, mask) + + pos_inds = target.eq(1).float() + neg_inds = target.lt(1).float() + + if _mask is not None: + pos_inds = pos_inds * _mask + neg_inds = neg_inds * _mask + + neg_weights = torch.pow(1 - target, self.beta) + + pos_loss = torch.log(output) * torch.pow(1 - output, + self.alpha) * pos_inds + neg_loss = torch.log(1 - output) * torch.pow( + output, self.alpha) * neg_weights * neg_inds + + num_pos = pos_inds.float().sum() + if num_pos == 0: + loss = -neg_loss.sum() + else: + loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos + return loss * self.loss_weight diff --git a/mmpose/models/losses/loss_wrappers.py b/mmpose/models/losses/loss_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..d821661b48a133ffd6c9232d5a6a2d3eb6bf0a50 --- /dev/null +++ b/mmpose/models/losses/loss_wrappers.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +import torch.nn as nn + +from mmpose.registry import MODELS +from mmpose.utils.typing import ConfigType + + +@MODELS.register_module() +class MultipleLossWrapper(nn.Module): + """A wrapper to collect multiple loss functions together and return a list + of losses in the same order. + + Args: + losses (list): List of Loss Config + """ + + def __init__(self, losses: list): + super().__init__() + self.num_losses = len(losses) + + loss_modules = [] + for loss_cfg in losses: + t_loss = MODELS.build(loss_cfg) + loss_modules.append(t_loss) + self.loss_modules = nn.ModuleList(loss_modules) + + def forward(self, input_list, target_list, keypoint_weights=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + input_list (List[Tensor]): List of inputs. + target_list (List[Tensor]): List of targets. + keypoint_weights (Tensor[N, K, D]): + Weights across different joint types. + """ + assert isinstance(input_list, list), '' + assert isinstance(target_list, list), '' + assert len(input_list) == len(target_list), '' + + losses = [] + for i in range(self.num_losses): + input_i = input_list[i] + target_i = target_list[i] + + loss_i = self.loss_modules[i](input_i, target_i, keypoint_weights) + losses.append(loss_i) + + return losses + + +@MODELS.register_module() +class CombinedLoss(nn.ModuleDict): + """A wrapper to combine multiple loss functions. These loss functions can + have different input type (e.g. heatmaps or regression values), and can + only be involed individually and explixitly. + + Args: + losses (Dict[str, ConfigType]): The names and configs of loss + functions to be wrapped + + Example:: + >>> heatmap_loss_cfg = dict(type='KeypointMSELoss') + >>> ae_loss_cfg = dict(type='AssociativeEmbeddingLoss') + >>> loss_module = CombinedLoss( + ... losses=dict( + ... heatmap_loss=heatmap_loss_cfg, + ... ae_loss=ae_loss_cfg)) + >>> loss_hm = loss_module.heatmap_loss(pred_heatmap, gt_heatmap) + >>> loss_ae = loss_module.ae_loss(pred_tags, keypoint_indices) + """ + + def __init__(self, losses: Dict[str, ConfigType]): + super().__init__() + for loss_name, loss_cfg in losses.items(): + self.add_module(loss_name, MODELS.build(loss_cfg)) diff --git a/mmpose/models/losses/regression_loss.py b/mmpose/models/losses/regression_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9a64a4adfe9c7429d6e72b794d6c2b01af14fe46 --- /dev/null +++ b/mmpose/models/losses/regression_loss.py @@ -0,0 +1,618 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpose.registry import MODELS +from ..utils.realnvp import RealNVP + + +@MODELS.register_module() +class RLELoss(nn.Module): + """RLE Loss. + + `Human Pose Regression With Residual Log-Likelihood Estimation + arXiv: `_. + + Code is modified from `the official implementation + `_. + + Args: + use_target_weight (bool): Option to use weighted loss. + Different joint types may have different target weights. + size_average (bool): Option to average the loss by the batch_size. + residual (bool): Option to add L1 loss and let the flow + learn the residual error distribution. + q_dis (string): Option for the identity Q(error) distribution, + Options: "laplace" or "gaussian" + """ + + def __init__(self, + use_target_weight=False, + size_average=True, + residual=True, + q_distribution='laplace'): + super(RLELoss, self).__init__() + self.size_average = size_average + self.use_target_weight = use_target_weight + self.residual = residual + self.q_distribution = q_distribution + + self.flow_model = RealNVP() + + def forward(self, pred, sigma, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + pred (Tensor[N, K, D]): Output regression. + sigma (Tensor[N, K, D]): Output sigma. + target (Tensor[N, K, D]): Target regression. + target_weight (Tensor[N, K, D]): + Weights across different joint types. + """ + sigma = sigma.sigmoid() + + error = (pred - target) / (sigma + 1e-9) + # (B, K, 2) + log_phi = self.flow_model.log_prob(error.reshape(-1, 2)) + log_phi = log_phi.reshape(target.shape[0], target.shape[1], 1) + log_sigma = torch.log(sigma).reshape(target.shape[0], target.shape[1], + 2) + nf_loss = log_sigma - log_phi + + if self.residual: + assert self.q_distribution in ['laplace', 'gaussian'] + if self.q_distribution == 'laplace': + loss_q = torch.log(sigma * 2) + torch.abs(error) + else: + loss_q = torch.log( + sigma * math.sqrt(2 * math.pi)) + 0.5 * error**2 + + loss = nf_loss + loss_q + else: + loss = nf_loss + + if self.use_target_weight: + assert target_weight is not None + loss *= target_weight + + if self.size_average: + loss /= len(loss) + + return loss.sum() + + +@MODELS.register_module() +class SmoothL1Loss(nn.Module): + """SmoothL1Loss loss. + + Args: + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, use_target_weight=False, loss_weight=1.): + super().__init__() + self.criterion = F.smooth_l1_loss + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + output (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + target_weight (torch.Tensor[N, K, D]): + Weights across different joint types. + """ + + if self.use_target_weight: + assert target_weight is not None + assert output.ndim >= target_weight.ndim + + for i in range(output.ndim - target_weight.ndim): + target_weight = target_weight.unsqueeze(-1) + + loss = self.criterion(output * target_weight, + target * target_weight) + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + +@MODELS.register_module() +class SoftWeightSmoothL1Loss(nn.Module): + """Smooth L1 loss with soft weight for regression. + + Args: + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + supervise_empty (bool): Whether to supervise the output with zero + weight. + beta (float): Specifies the threshold at which to change between + L1 and L2 loss. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, + use_target_weight=False, + supervise_empty=True, + beta=1.0, + loss_weight=1.): + super().__init__() + + reduction = 'none' if use_target_weight else 'mean' + self.criterion = partial( + self.smooth_l1_loss, reduction=reduction, beta=beta) + + self.supervise_empty = supervise_empty + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + @staticmethod + def smooth_l1_loss(input, target, reduction='none', beta=1.0): + """Re-implement torch.nn.functional.smooth_l1_loss with beta to support + pytorch <= 1.6.""" + delta = input - target + mask = delta.abs() < beta + delta[mask] = (delta[mask]).pow(2) / (2 * beta) + delta[~mask] = delta[~mask].abs() - beta / 2 + + if reduction == 'mean': + return delta.mean() + elif reduction == 'sum': + return delta.sum() + elif reduction == 'none': + return delta + else: + raise ValueError(f'reduction must be \'mean\', \'sum\' or ' + f'\'none\', but got \'{reduction}\'') + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + output (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + target_weight (torch.Tensor[N, K, D]): + Weights across different joint types. + """ + if self.use_target_weight: + assert target_weight is not None + assert output.ndim >= target_weight.ndim + + for i in range(output.ndim - target_weight.ndim): + target_weight = target_weight.unsqueeze(-1) + + loss = self.criterion(output, target) * target_weight + if self.supervise_empty: + loss = loss.mean() + else: + num_elements = torch.nonzero(target_weight > 0).size()[0] + loss = loss.sum() / max(num_elements, 1.0) + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + +@MODELS.register_module() +class WingLoss(nn.Module): + """Wing Loss. paper ref: 'Wing Loss for Robust Facial Landmark Localisation + with Convolutional Neural Networks' Feng et al. CVPR'2018. + + Args: + omega (float): Also referred to as width. + epsilon (float): Also referred to as curvature. + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, + omega=10.0, + epsilon=2.0, + use_target_weight=False, + loss_weight=1.): + super().__init__() + self.omega = omega + self.epsilon = epsilon + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + # constant that smoothly links the piecewise-defined linear + # and nonlinear parts + self.C = self.omega * (1.0 - math.log(1.0 + self.omega / self.epsilon)) + + def criterion(self, pred, target): + """Criterion of wingloss. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + pred (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + """ + delta = (target - pred).abs() + losses = torch.where( + delta < self.omega, + self.omega * torch.log(1.0 + delta / self.epsilon), delta - self.C) + return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0) + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + output (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + target_weight (torch.Tensor[N,K,D]): + Weights across different joint types. + """ + if self.use_target_weight: + assert target_weight is not None + loss = self.criterion(output * target_weight, + target * target_weight) + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + +@MODELS.register_module() +class SoftWingLoss(nn.Module): + """Soft Wing Loss 'Structure-Coherent Deep Feature Learning for Robust Face + Alignment' Lin et al. TIP'2021. + + loss = + 1. |x| , if |x| < omega1 + 2. omega2*ln(1+|x|/epsilon) + B, if |x| >= omega1 + + Args: + omega1 (float): The first threshold. + omega2 (float): The second threshold. + epsilon (float): Also referred to as curvature. + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, + omega1=2.0, + omega2=20.0, + epsilon=0.5, + use_target_weight=False, + loss_weight=1.): + super().__init__() + self.omega1 = omega1 + self.omega2 = omega2 + self.epsilon = epsilon + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + # constant that smoothly links the piecewise-defined linear + # and nonlinear parts + self.B = self.omega1 - self.omega2 * math.log(1.0 + self.omega1 / + self.epsilon) + + def criterion(self, pred, target): + """Criterion of wingloss. + + Note: + batch_size: N + num_keypoints: K + dimension of keypoints: D (D=2 or D=3) + + Args: + pred (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + """ + delta = (target - pred).abs() + losses = torch.where( + delta < self.omega1, delta, + self.omega2 * torch.log(1.0 + delta / self.epsilon) + self.B) + return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0) + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + batch_size: N + num_keypoints: K + dimension of keypoints: D (D=2 or D=3) + + Args: + output (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + target_weight (torch.Tensor[N, K, D]): + Weights across different joint types. + """ + if self.use_target_weight: + assert target_weight is not None + loss = self.criterion(output * target_weight, + target * target_weight) + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + +@MODELS.register_module() +class MPJPELoss(nn.Module): + """MPJPE (Mean Per Joint Position Error) loss. + + Args: + use_target_weight (bool): Option to use weighted MSE loss. + Different joint types may have different target weights. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, use_target_weight=False, loss_weight=1.): + super().__init__() + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + output (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + target_weight (torch.Tensor[N,K,D]): + Weights across different joint types. + """ + + if self.use_target_weight: + assert target_weight is not None + loss = torch.mean( + torch.norm((output - target) * target_weight, dim=-1)) + else: + loss = torch.mean(torch.norm(output - target, dim=-1)) + + return loss * self.loss_weight + + +@MODELS.register_module() +class L1Loss(nn.Module): + """L1Loss loss .""" + + def __init__(self, use_target_weight=False, loss_weight=1.): + super().__init__() + self.criterion = F.l1_loss + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + output (torch.Tensor[N, K, 2]): Output regression. + target (torch.Tensor[N, K, 2]): Target regression. + target_weight (torch.Tensor[N, K, 2]): + Weights across different joint types. + """ + if self.use_target_weight: + assert target_weight is not None + loss = self.criterion(output * target_weight, + target * target_weight) + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + +@MODELS.register_module() +class MSELoss(nn.Module): + """MSE loss for coordinate regression.""" + + def __init__(self, use_target_weight=False, loss_weight=1.): + super().__init__() + self.criterion = F.mse_loss + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + output (torch.Tensor[N, K, 2]): Output regression. + target (torch.Tensor[N, K, 2]): Target regression. + target_weight (torch.Tensor[N, K, 2]): + Weights across different joint types. + """ + + if self.use_target_weight: + assert target_weight is not None + loss = self.criterion(output * target_weight, + target * target_weight) + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + +@MODELS.register_module() +class BoneLoss(nn.Module): + """Bone length loss. + + Args: + joint_parents (list): Indices of each joint's parent joint. + use_target_weight (bool): Option to use weighted bone loss. + Different bone types may have different target weights. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, joint_parents, use_target_weight=False, loss_weight=1.): + super().__init__() + self.joint_parents = joint_parents + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + self.non_root_indices = [] + for i in range(len(self.joint_parents)): + if i != self.joint_parents[i]: + self.non_root_indices.append(i) + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + output (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + target_weight (torch.Tensor[N, K-1]): + Weights across different bone types. + """ + output_bone = torch.norm( + output - output[:, self.joint_parents, :], + dim=-1)[:, self.non_root_indices] + target_bone = torch.norm( + target - target[:, self.joint_parents, :], + dim=-1)[:, self.non_root_indices] + if self.use_target_weight: + assert target_weight is not None + loss = torch.mean( + torch.abs((output_bone * target_weight).mean(dim=0) - + (target_bone * target_weight).mean(dim=0))) + else: + loss = torch.mean( + torch.abs(output_bone.mean(dim=0) - target_bone.mean(dim=0))) + + return loss * self.loss_weight + + +@MODELS.register_module() +class SemiSupervisionLoss(nn.Module): + """Semi-supervision loss for unlabeled data. It is composed of projection + loss and bone loss. + + Paper ref: `3D human pose estimation in video with temporal convolutions + and semi-supervised training` Dario Pavllo et al. CVPR'2019. + + Args: + joint_parents (list): Indices of each joint's parent joint. + projection_loss_weight (float): Weight for projection loss. + bone_loss_weight (float): Weight for bone loss. + warmup_iterations (int): Number of warmup iterations. In the first + `warmup_iterations` iterations, the model is trained only on + labeled data, and semi-supervision loss will be 0. + This is a workaround since currently we cannot access + epoch number in loss functions. Note that the iteration number in + an epoch can be changed due to different GPU numbers in multi-GPU + settings. So please set this parameter carefully. + warmup_iterations = dataset_size // samples_per_gpu // gpu_num + * warmup_epochs + """ + + def __init__(self, + joint_parents, + projection_loss_weight=1., + bone_loss_weight=1., + warmup_iterations=0): + super().__init__() + self.criterion_projection = MPJPELoss( + loss_weight=projection_loss_weight) + self.criterion_bone = BoneLoss( + joint_parents, loss_weight=bone_loss_weight) + self.warmup_iterations = warmup_iterations + self.num_iterations = 0 + + @staticmethod + def project_joints(x, intrinsics): + """Project 3D joint coordinates to 2D image plane using camera + intrinsic parameters. + + Args: + x (torch.Tensor[N, K, 3]): 3D joint coordinates. + intrinsics (torch.Tensor[N, 4] | torch.Tensor[N, 9]): Camera + intrinsics: f (2), c (2), k (3), p (2). + """ + while intrinsics.dim() < x.dim(): + intrinsics.unsqueeze_(1) + f = intrinsics[..., :2] + c = intrinsics[..., 2:4] + _x = torch.clamp(x[:, :, :2] / x[:, :, 2:], -1, 1) + if intrinsics.shape[-1] == 9: + k = intrinsics[..., 4:7] + p = intrinsics[..., 7:9] + + r2 = torch.sum(_x[:, :, :2]**2, dim=-1, keepdim=True) + radial = 1 + torch.sum( + k * torch.cat((r2, r2**2, r2**3), dim=-1), + dim=-1, + keepdim=True) + tan = torch.sum(p * _x, dim=-1, keepdim=True) + _x = _x * (radial + tan) + p * r2 + _x = f * _x + c + return _x + + def forward(self, output, target): + losses = dict() + + self.num_iterations += 1 + if self.num_iterations <= self.warmup_iterations: + return losses + + labeled_pose = output['labeled_pose'] + unlabeled_pose = output['unlabeled_pose'] + unlabeled_traj = output['unlabeled_traj'] + unlabeled_target_2d = target['unlabeled_target_2d'] + intrinsics = target['intrinsics'] + + # projection loss + unlabeled_output = unlabeled_pose + unlabeled_traj + unlabeled_output_2d = self.project_joints(unlabeled_output, intrinsics) + loss_proj = self.criterion_projection(unlabeled_output_2d, + unlabeled_target_2d, None) + losses['proj_loss'] = loss_proj + + # bone loss + loss_bone = self.criterion_bone(unlabeled_pose, labeled_pose, None) + losses['bone_loss'] = loss_bone + + return losses diff --git a/mmpose/models/necks/__init__.py b/mmpose/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f9105cb39cdfaba4f26d903b8bffbe05f4272a --- /dev/null +++ b/mmpose/models/necks/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .fmap_proc_neck import FeatureMapProcessor +from .fpn import FPN +from .gap_neck import GlobalAveragePooling +from .posewarper_neck import PoseWarperNeck + +__all__ = [ + 'GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'FeatureMapProcessor' +] diff --git a/mmpose/models/necks/__pycache__/__init__.cpython-38.pyc b/mmpose/models/necks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2c4ded6e13005fc78afa8a00d53825f5622984e Binary files /dev/null and b/mmpose/models/necks/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/necks/__pycache__/fmap_proc_neck.cpython-38.pyc b/mmpose/models/necks/__pycache__/fmap_proc_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..278aa6129866721b6c25c3513cc14b7cae35d8b4 Binary files /dev/null and b/mmpose/models/necks/__pycache__/fmap_proc_neck.cpython-38.pyc differ diff --git a/mmpose/models/necks/__pycache__/fpn.cpython-38.pyc b/mmpose/models/necks/__pycache__/fpn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2807d62818c544d1dc35992223459fd32e01b9c Binary files /dev/null and b/mmpose/models/necks/__pycache__/fpn.cpython-38.pyc differ diff --git a/mmpose/models/necks/__pycache__/gap_neck.cpython-38.pyc b/mmpose/models/necks/__pycache__/gap_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3764a3b27eb7209f4a85a5695423861493cf2e1 Binary files /dev/null and b/mmpose/models/necks/__pycache__/gap_neck.cpython-38.pyc differ diff --git a/mmpose/models/necks/__pycache__/posewarper_neck.cpython-38.pyc b/mmpose/models/necks/__pycache__/posewarper_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..815a18a3620173178e9a9ef38fb5c70f6c879bc2 Binary files /dev/null and b/mmpose/models/necks/__pycache__/posewarper_neck.cpython-38.pyc differ diff --git a/mmpose/models/necks/fmap_proc_neck.py b/mmpose/models/necks/fmap_proc_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3a4d7bf44ab07641a4968f143e17c19b24743b --- /dev/null +++ b/mmpose/models/necks/fmap_proc_neck.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmpose.models.utils.ops import resize +from mmpose.registry import MODELS + + +@MODELS.register_module() +class FeatureMapProcessor(nn.Module): + """A PyTorch module for selecting, concatenating, and rescaling feature + maps. + + Args: + select_index (Optional[Union[int, Tuple[int]]], optional): Index or + indices of feature maps to select. Defaults to None, which means + all feature maps are used. + concat (bool, optional): Whether to concatenate the selected feature + maps. Defaults to False. + scale_factor (float, optional): The scaling factor to apply to the + feature maps. Defaults to 1.0. + apply_relu (bool, optional): Whether to apply ReLU on input feature + maps. Defaults to False. + align_corners (bool, optional): Whether to align corners when resizing + the feature maps. Defaults to False. + """ + + def __init__( + self, + select_index: Optional[Union[int, Tuple[int]]] = None, + concat: bool = False, + scale_factor: float = 1.0, + apply_relu: bool = False, + align_corners: bool = False, + ): + super().__init__() + + if isinstance(select_index, int): + select_index = (select_index, ) + self.select_index = select_index + self.concat = concat + + assert ( + scale_factor > 0 + ), f'the argument `scale_factor` must be positive, ' \ + f'but got {scale_factor}' + self.scale_factor = scale_factor + self.apply_relu = apply_relu + self.align_corners = align_corners + + def forward(self, inputs: Union[Tensor, Sequence[Tensor]] + ) -> Union[Tensor, List[Tensor]]: + + if not isinstance(inputs, (tuple, list)): + sequential_input = False + inputs = [inputs] + else: + sequential_input = True + + if self.select_index is not None: + inputs = [inputs[i] for i in self.select_index] + + if self.concat: + inputs = self._concat(inputs) + + if self.apply_relu: + inputs = [F.relu(x) for x in inputs] + + if self.scale_factor != 1.0: + inputs = self._rescale(inputs) + + if not sequential_input: + inputs = inputs[0] + + return inputs + + def _concat(self, inputs: Sequence[Tensor]) -> List[Tensor]: + size = inputs[0].shape[-2:] + resized_inputs = [ + resize( + x, + size=size, + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + return [torch.cat(resized_inputs, dim=1)] + + def _rescale(self, inputs: Sequence[Tensor]) -> List[Tensor]: + rescaled_inputs = [ + resize( + x, + scale_factor=self.scale_factor, + mode='bilinear', + align_corners=self.align_corners, + ) for x in inputs + ] + return rescaled_inputs diff --git a/mmpose/models/necks/fpn.py b/mmpose/models/necks/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d3311bda792898dd1bc7ef9b9462db7b01ce05 --- /dev/null +++ b/mmpose/models/necks/fpn.py @@ -0,0 +1,206 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import xavier_init + +from mmpose.registry import MODELS + + +@MODELS.register_module() +class FPN(nn.Module): + r"""Feature Pyramid Network. + + This is an implementation of paper `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: dict(mode='nearest'). + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest')): + super().__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1 or end_level == self.num_ins - 1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level is not the last level, no extra level is allowed + self.backbone_end_level = end_level + 1 + assert end_level < self.num_ins + assert num_outs == end_level - start_level + 1 + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + self.add_extra_convs = 'on_input' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + def init_weights(self): + """Initialize model weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + # fix runtime error of "+=" inplace operation in PyTorch 1.10 + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return outs diff --git a/mmpose/models/necks/gap_neck.py b/mmpose/models/necks/gap_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..58ce5d939ffdeb912a02e8b1823ab073cbc3d9e3 --- /dev/null +++ b/mmpose/models/necks/gap_neck.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmpose.registry import MODELS + + +@MODELS.register_module() +class GlobalAveragePooling(nn.Module): + """Global Average Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + """ + + def __init__(self): + super().__init__() + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + + def init_weights(self): + pass + + def forward(self, inputs): + """Forward function.""" + + if isinstance(inputs, tuple): + outs = tuple([self.gap(x) for x in inputs]) + outs = tuple( + [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) + elif isinstance(inputs, list): + outs = [self.gap(x) for x in inputs] + outs = [out.view(x.size(0), -1) for out, x in zip(outs, inputs)] + elif isinstance(inputs, torch.Tensor): + outs = self.gap(inputs) + outs = outs.view(inputs.size(0), -1) + else: + raise TypeError('neck inputs should be tuple or torch.tensor') + return outs diff --git a/mmpose/models/necks/posewarper_neck.py b/mmpose/models/necks/posewarper_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..517fabd2e839878e7cf692c91adad450f432e8f0 --- /dev/null +++ b/mmpose/models/necks/posewarper_neck.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import torch +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import constant_init, normal_init +from mmengine.utils import digit_version +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpose.models.utils.ops import resize +from mmpose.registry import MODELS +from ..backbones.resnet import BasicBlock, Bottleneck + +try: + from mmcv.ops import DeformConv2d + has_mmcv_full = True +except (ImportError, ModuleNotFoundError): + has_mmcv_full = False + + +@MODELS.register_module() +class PoseWarperNeck(nn.Module): + """PoseWarper neck. + + `"Learning temporal pose estimation from sparsely-labeled videos" + `_. + + Args: + in_channels (int): Number of input channels from backbone + out_channels (int): Number of output channels + inner_channels (int): Number of intermediate channels of the res block + deform_groups (int): Number of groups in the deformable conv + dilations (list|tuple): different dilations of the offset conv layers + trans_conv_kernel (int): the kernel of the trans conv layer, which is + used to get heatmap from the output of backbone. Default: 1 + res_blocks_cfg (dict|None): config of residual blocks. If None, + use the default values. If not None, it should contain the + following keys: + + - block (str): the type of residual block, Default: 'BASIC'. + - num_blocks (int): the number of blocks, Default: 20. + + offsets_kernel (int): the kernel of offset conv layer. + deform_conv_kernel (int): the kernel of defomrable conv layer. + in_index (int|Sequence[int]): Input feature index. Default: 0 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + Default: None. + + - 'resize_concat': Multiple feature maps will be resize to \ + the same size as first one and than concat together. \ + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into \ + a list and passed into decode head. + - None: Only one select feature map is allowed. + + freeze_trans_layer (bool): Whether to freeze the transition layer + (stop grad and set eval mode). Default: True. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + im2col_step (int): the argument `im2col_step` in deformable conv, + Default: 80. + """ + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + minimum_mmcv_version = '1.3.17' + + def __init__(self, + in_channels, + out_channels, + inner_channels, + deform_groups=17, + dilations=(3, 6, 12, 18, 24), + trans_conv_kernel=1, + res_blocks_cfg=None, + offsets_kernel=3, + deform_conv_kernel=3, + in_index=0, + input_transform=None, + freeze_trans_layer=True, + norm_eval=False, + im2col_step=80): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.inner_channels = inner_channels + self.deform_groups = deform_groups + self.dilations = dilations + self.trans_conv_kernel = trans_conv_kernel + self.res_blocks_cfg = res_blocks_cfg + self.offsets_kernel = offsets_kernel + self.deform_conv_kernel = deform_conv_kernel + self.in_index = in_index + self.input_transform = input_transform + self.freeze_trans_layer = freeze_trans_layer + self.norm_eval = norm_eval + self.im2col_step = im2col_step + + identity_trans_layer = False + + assert trans_conv_kernel in [0, 1, 3] + kernel_size = trans_conv_kernel + if kernel_size == 3: + padding = 1 + elif kernel_size == 1: + padding = 0 + else: + # 0 for Identity mapping. + identity_trans_layer = True + + if identity_trans_layer: + self.trans_layer = nn.Identity() + else: + self.trans_layer = build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding) + + # build chain of residual blocks + if res_blocks_cfg is not None and not isinstance(res_blocks_cfg, dict): + raise TypeError('res_blocks_cfg should be dict or None.') + + if res_blocks_cfg is None: + block_type = 'BASIC' + num_blocks = 20 + else: + block_type = res_blocks_cfg.get('block', 'BASIC') + num_blocks = res_blocks_cfg.get('num_blocks', 20) + + block = self.blocks_dict[block_type] + + res_layers = [] + downsample = nn.Sequential( + build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=out_channels, + out_channels=inner_channels, + kernel_size=1, + stride=1, + bias=False), + build_norm_layer(dict(type='BN'), inner_channels)[1]) + res_layers.append( + block( + in_channels=out_channels, + out_channels=inner_channels, + downsample=downsample)) + + for _ in range(1, num_blocks): + res_layers.append(block(inner_channels, inner_channels)) + self.offset_feats = nn.Sequential(*res_layers) + + # build offset layers + self.num_offset_layers = len(dilations) + assert self.num_offset_layers > 0, 'Number of offset layers ' \ + 'should be larger than 0.' + + target_offset_channels = 2 * offsets_kernel**2 * deform_groups + + offset_layers = [ + build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=inner_channels, + out_channels=target_offset_channels, + kernel_size=offsets_kernel, + stride=1, + dilation=dilations[i], + padding=dilations[i], + bias=False, + ) for i in range(self.num_offset_layers) + ] + self.offset_layers = nn.ModuleList(offset_layers) + + # build deformable conv layers + assert digit_version(mmcv.__version__) >= \ + digit_version(self.minimum_mmcv_version), \ + f'Current MMCV version: {mmcv.__version__}, ' \ + f'but MMCV >= {self.minimum_mmcv_version} is required, see ' \ + f'https://github.com/open-mmlab/mmcv/issues/1440, ' \ + f'Please install the latest MMCV.' + + if has_mmcv_full: + deform_conv_layers = [ + DeformConv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=deform_conv_kernel, + stride=1, + padding=int(deform_conv_kernel / 2) * dilations[i], + dilation=dilations[i], + deform_groups=deform_groups, + im2col_step=self.im2col_step, + ) for i in range(self.num_offset_layers) + ] + else: + raise ImportError('Please install the full version of mmcv ' + 'to use `DeformConv2d`.') + + self.deform_conv_layers = nn.ModuleList(deform_conv_layers) + + self.freeze_layers() + + def freeze_layers(self): + if self.freeze_trans_layer: + self.trans_layer.eval() + + for param in self.trans_layer.parameters(): + param.requires_grad = False + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + elif isinstance(m, DeformConv2d): + filler = torch.zeros([ + m.weight.size(0), + m.weight.size(1), + m.weight.size(2), + m.weight.size(3) + ], + dtype=torch.float32, + device=m.weight.device) + for k in range(m.weight.size(0)): + filler[k, k, + int(m.weight.size(2) / 2), + int(m.weight.size(3) / 2)] = 1.0 + m.weight = torch.nn.Parameter(filler) + m.weight.requires_grad = True + + # posewarper offset layer weight initialization + for m in self.offset_layers.modules(): + constant_init(m, 0) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor] | Tensor): multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + if not isinstance(inputs, list): + return inputs + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def forward(self, inputs, frame_weight): + assert isinstance(inputs, (list, tuple)), 'PoseWarperNeck inputs ' \ + 'should be list or tuple, even though the length is 1, ' \ + 'for unified processing.' + + output_heatmap = 0 + if len(inputs) > 1: + inputs = [self._transform_inputs(input) for input in inputs] + inputs = [self.trans_layer(input) for input in inputs] + + # calculate difference features + diff_features = [ + self.offset_feats(inputs[0] - input) for input in inputs + ] + + for i in range(len(inputs)): + if frame_weight[i] == 0: + continue + warped_heatmap = 0 + for j in range(self.num_offset_layers): + offset = (self.offset_layers[j](diff_features[i])) + warped_heatmap_tmp = self.deform_conv_layers[j](inputs[i], + offset) + warped_heatmap += warped_heatmap_tmp / \ + self.num_offset_layers + + output_heatmap += warped_heatmap * frame_weight[i] + + else: + inputs = inputs[0] + inputs = self._transform_inputs(inputs) + inputs = self.trans_layer(inputs) + + num_frames = len(frame_weight) + batch_size = inputs.size(0) // num_frames + ref_x = inputs[:batch_size] + ref_x_tiled = ref_x.repeat(num_frames, 1, 1, 1) + + offset_features = self.offset_feats(ref_x_tiled - inputs) + + warped_heatmap = 0 + for j in range(self.num_offset_layers): + offset = self.offset_layers[j](offset_features) + + warped_heatmap_tmp = self.deform_conv_layers[j](inputs, offset) + warped_heatmap += warped_heatmap_tmp / self.num_offset_layers + + for i in range(num_frames): + if frame_weight[i] == 0: + continue + output_heatmap += warped_heatmap[i * batch_size:(i + 1) * + batch_size] * frame_weight[i] + + return output_heatmap + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self.freeze_layers() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpose/models/pose_estimators/__init__.py b/mmpose/models/pose_estimators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ead1a979e5a8c38ae98240b181e425ab7eeed35 --- /dev/null +++ b/mmpose/models/pose_estimators/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bottomup import BottomupPoseEstimator +from .topdown import TopdownPoseEstimator + +__all__ = ['TopdownPoseEstimator', 'BottomupPoseEstimator'] diff --git a/mmpose/models/pose_estimators/__pycache__/__init__.cpython-38.pyc b/mmpose/models/pose_estimators/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a039392f447e0ebd7daf846a51f579fb2b1e5da Binary files /dev/null and b/mmpose/models/pose_estimators/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/pose_estimators/__pycache__/base.cpython-38.pyc b/mmpose/models/pose_estimators/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c6404a265eed582ffe953e0c9be756c5db17e14 Binary files /dev/null and b/mmpose/models/pose_estimators/__pycache__/base.cpython-38.pyc differ diff --git a/mmpose/models/pose_estimators/__pycache__/bottomup.cpython-38.pyc b/mmpose/models/pose_estimators/__pycache__/bottomup.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4fbfe2e96c4b850e0e69a4d579dee03f27ee3db Binary files /dev/null and b/mmpose/models/pose_estimators/__pycache__/bottomup.cpython-38.pyc differ diff --git a/mmpose/models/pose_estimators/__pycache__/topdown.cpython-38.pyc b/mmpose/models/pose_estimators/__pycache__/topdown.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d1cedc346a0e34064780c2f66e497ea75de9491 Binary files /dev/null and b/mmpose/models/pose_estimators/__pycache__/topdown.cpython-38.pyc differ diff --git a/mmpose/models/pose_estimators/base.py b/mmpose/models/pose_estimators/base.py new file mode 100644 index 0000000000000000000000000000000000000000..73d60de93a245a37093a723408fc25bc5370b448 --- /dev/null +++ b/mmpose/models/pose_estimators/base.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Tuple, Union + +import torch +from mmengine.model import BaseModel +from torch import Tensor + +from mmpose.datasets.datasets.utils import parse_pose_metainfo +from mmpose.models.utils import check_and_update_config +from mmpose.registry import MODELS +from mmpose.utils.typing import (ConfigType, ForwardResults, OptConfigType, + Optional, OptMultiConfig, OptSampleList, + SampleList) + + +class BasePoseEstimator(BaseModel, metaclass=ABCMeta): + """Base class for pose estimators. + + Args: + data_preprocessor (dict | ConfigDict, optional): The pre-processing + config of :class:`BaseDataPreprocessor`. Defaults to ``None`` + init_cfg (dict | ConfigDict): The model initialization config. + Defaults to ``None`` + metainfo (dict): Meta information for dataset, such as keypoints + definition and properties. If set, the metainfo of the input data + batch will be overridden. For more details, please refer to + https://mmpose.readthedocs.io/en/latest/user_guides/ + prepare_datasets.html#create-a-custom-dataset-info- + config-file-for-the-dataset. Defaults to ``None`` + """ + _version = 2 + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None, + metainfo: Optional[dict] = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.metainfo = self._load_metainfo(metainfo) + + self.backbone = MODELS.build(backbone) + + # the PR #2108 and #2126 modified the interface of neck and head. + # The following function automatically detects outdated + # configurations and updates them accordingly, while also providing + # clear and concise information on the changes made. + neck, head = check_and_update_config(neck, head) + + if neck is not None: + self.neck = MODELS.build(neck) + + if head is not None: + self.head = MODELS.build(head) + + self.train_cfg = train_cfg if train_cfg else {} + self.test_cfg = test_cfg if test_cfg else {} + + # Register the hook to automatically convert old version state dicts + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + @property + def with_neck(self) -> bool: + """bool: whether the pose estimator has a neck.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_head(self) -> bool: + """bool: whether the pose estimator has a head.""" + return hasattr(self, 'head') and self.head is not None + + @staticmethod + def _load_metainfo(metainfo: dict = None) -> dict: + """Collect meta information from the dictionary of meta. + + Args: + metainfo (dict): Raw data of pose meta information. + + Returns: + dict: Parsed meta information. + """ + + if metainfo is None: + return None + + if not isinstance(metainfo, dict): + raise TypeError( + f'metainfo should be a dict, but got {type(metainfo)}') + + metainfo = parse_pose_metainfo(metainfo) + return metainfo + + def forward(self, + inputs: torch.Tensor, + data_samples: OptSampleList, + mode: str = 'tensor') -> ForwardResults: + """The unified entry for a forward process in both training and test. + + The method should accept three modes: 'tensor', 'predict' and 'loss': + + - 'tensor': Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - 'predict': Forward and return the predictions, which are fully + processed to a list of :obj:`PoseDataSample`. + - 'loss': Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general + data_samples (list[:obj:`PoseDataSample`], optional): The + annotation of every sample. Defaults to ``None`` + mode (str): Set the forward mode and return value type. Defaults + to ``'tensor'`` + + Returns: + The return type depends on ``mode``. + + - If ``mode='tensor'``, return a tensor or a tuple of tensors + - If ``mode='predict'``, return a list of :obj:``PoseDataSample`` + that contains the pose predictions + - If ``mode='loss'``, return a dict of tensor(s) which is the loss + function value + """ + if mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + # use customed metainfo to override the default metainfo + if self.metainfo is not None: + for data_sample in data_samples: + data_sample.set_metainfo(self.metainfo) + return self.predict(inputs, data_samples) + elif mode == 'tensor': + return self._forward(inputs) + else: + raise RuntimeError(f'Invalid mode "{mode}". ' + 'Only supports loss, predict and tensor mode.') + + @abstractmethod + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + @abstractmethod + def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing.""" + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None + ) -> Union[Tensor, Tuple[Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + + Returns: + Union[Tensor | Tuple[Tensor]]: forward output of the network. + """ + + x = self.extract_feat(inputs) + if self.with_head: + x = self.head.forward(x) + + return x + + def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + inputs (Tensor): Image tensor with shape (N, C, H ,W). + + Returns: + tuple[Tensor]: Multi-level features that may have various + resolutions. + """ + x = self.backbone(inputs) + if self.with_neck: + x = self.neck(x) + + return x + + def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, + **kwargs): + """A hook function to convert old-version state dict of + :class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a + compatible format of :class:`HeatmapHead`. + + The hook will be automatically registered during initialization. + """ + version = local_meta.get('version', None) + if version and version >= self._version: + return + + # convert old-version state dict + keys = list(state_dict.keys()) + for k in keys: + if 'keypoint_head' in k: + v = state_dict.pop(k) + k = k.replace('keypoint_head', 'head') + state_dict[k] = v diff --git a/mmpose/models/pose_estimators/bottomup.py b/mmpose/models/pose_estimators/bottomup.py new file mode 100644 index 0000000000000000000000000000000000000000..5400f2478e411bbf39b875ad2c8bfd6532ffa4b8 --- /dev/null +++ b/mmpose/models/pose_estimators/bottomup.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import zip_longest +from typing import List, Optional, Union + +from mmengine.utils import is_list_of +from torch import Tensor + +from mmpose.registry import MODELS +from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, + OptMultiConfig, PixelDataList, SampleList) +from .base import BasePoseEstimator + + +@MODELS.register_module() +class BottomupPoseEstimator(BasePoseEstimator): + """Base class for bottom-up pose estimators. + + Args: + backbone (dict): The backbone config + neck (dict, optional): The neck config. Defaults to ``None`` + head (dict, optional): The head config. Defaults to ``None`` + train_cfg (dict, optional): The runtime config for training process. + Defaults to ``None`` + test_cfg (dict, optional): The runtime config for testing process. + Defaults to ``None`` + data_preprocessor (dict, optional): The data preprocessing config to + build the instance of :class:`BaseDataPreprocessor`. Defaults to + ``None``. + init_cfg (dict, optional): The config to control the initialization. + Defaults to ``None`` + """ + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + neck=neck, + head=head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`PoseDataSample`]): The batch + data samples. + + Returns: + dict: A dictionary of losses. + """ + feats = self.extract_feat(inputs) + + losses = dict() + + if self.with_head: + losses.update( + self.head.loss(feats, data_samples, train_cfg=self.train_cfg)) + + return losses + + def predict(self, inputs: Union[Tensor, List[Tensor]], + data_samples: SampleList) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor | List[Tensor]): Input image in tensor or image + pyramid as a list of tensors. Each tensor is in shape + [B, C, H, W] + data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + + Returns: + list[:obj:`PoseDataSample`]: The pose estimation results of the + input images. The return value is `PoseDataSample` instances with + ``pred_instances`` and ``pred_fields``(optional) field , and + ``pred_instances`` usually contains the following keys: + + - keypoints (Tensor): predicted keypoint coordinates in shape + (num_instances, K, D) where K is the keypoint number and D + is the keypoint dimension + - keypoint_scores (Tensor): predicted keypoint scores in shape + (num_instances, K) + """ + assert self.with_head, ( + 'The model must have head to perform prediction.') + + multiscale_test = self.test_cfg.get('multiscale_test', False) + flip_test = self.test_cfg.get('flip_test', False) + + # enable multi-scale test + aug_scales = data_samples[0].metainfo.get('aug_scales', None) + if multiscale_test: + assert isinstance(aug_scales, list) + assert is_list_of(inputs, Tensor) + # `inputs` includes images in original and augmented scales + assert len(inputs) == len(aug_scales) + 1 + else: + assert isinstance(inputs, Tensor) + # single-scale test + inputs = [inputs] + + feats = [] + for _inputs in inputs: + if flip_test: + _feats_orig = self.extract_feat(_inputs) + _feats_flip = self.extract_feat(_inputs.flip(-1)) + _feats = [_feats_orig, _feats_flip] + else: + _feats = self.extract_feat(_inputs) + + feats.append(_feats) + + if not multiscale_test: + feats = feats[0] + + preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg) + + if isinstance(preds, tuple): + batch_pred_instances, batch_pred_fields = preds + else: + batch_pred_instances = preds + batch_pred_fields = None + + results = self.add_pred_to_datasample(batch_pred_instances, + batch_pred_fields, data_samples) + + return results + + def add_pred_to_datasample(self, batch_pred_instances: InstanceList, + batch_pred_fields: Optional[PixelDataList], + batch_data_samples: SampleList) -> SampleList: + """Add predictions into data samples. + + Args: + batch_pred_instances (List[InstanceData]): The predicted instances + of the input data batch + batch_pred_fields (List[PixelData], optional): The predicted + fields (e.g. heatmaps) of the input batch + batch_data_samples (List[PoseDataSample]): The input data batch + + Returns: + List[PoseDataSample]: A list of data samples where the predictions + are stored in the ``pred_instances`` field of each data sample. + The length of the list is the batch size when ``merge==False``, or + 1 when ``merge==True``. + """ + assert len(batch_pred_instances) == len(batch_data_samples) + if batch_pred_fields is None: + batch_pred_fields = [] + + for pred_instances, pred_fields, data_sample in zip_longest( + batch_pred_instances, batch_pred_fields, batch_data_samples): + + # convert keypoint coordinates from input space to image space + input_size = data_sample.metainfo['input_size'] + input_center = data_sample.metainfo['input_center'] + input_scale = data_sample.metainfo['input_scale'] + + pred_instances.keypoints = pred_instances.keypoints / input_size \ + * input_scale + input_center - 0.5 * input_scale + + data_sample.pred_instances = pred_instances + + if pred_fields is not None: + data_sample.pred_fields = pred_fields + + return batch_data_samples diff --git a/mmpose/models/pose_estimators/topdown.py b/mmpose/models/pose_estimators/topdown.py new file mode 100644 index 0000000000000000000000000000000000000000..89b332893f0c4e051084a9eb86addf75433c0534 --- /dev/null +++ b/mmpose/models/pose_estimators/topdown.py @@ -0,0 +1,182 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import zip_longest +from typing import Optional + +from torch import Tensor + +from mmpose.registry import MODELS +from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, + OptMultiConfig, PixelDataList, SampleList) +from .base import BasePoseEstimator + + +@MODELS.register_module() +class TopdownPoseEstimator(BasePoseEstimator): + """Base class for top-down pose estimators. + + Args: + backbone (dict): The backbone config + neck (dict, optional): The neck config. Defaults to ``None`` + head (dict, optional): The head config. Defaults to ``None`` + train_cfg (dict, optional): The runtime config for training process. + Defaults to ``None`` + test_cfg (dict, optional): The runtime config for testing process. + Defaults to ``None`` + data_preprocessor (dict, optional): The data preprocessing config to + build the instance of :class:`BaseDataPreprocessor`. Defaults to + ``None`` + init_cfg (dict, optional): The config to control the initialization. + Defaults to ``None`` + metainfo (dict): Meta information for dataset, such as keypoints + definition and properties. If set, the metainfo of the input data + batch will be overridden. For more details, please refer to + https://mmpose.readthedocs.io/en/latest/user_guides/ + prepare_datasets.html#create-a-custom-dataset-info- + config-file-for-the-dataset. Defaults to ``None`` + """ + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None, + metainfo: Optional[dict] = None): + super().__init__( + backbone=backbone, + neck=neck, + head=head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg, + metainfo=metainfo) + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`PoseDataSample`]): The batch + data samples. + + Returns: + dict: A dictionary of losses. + """ + feats = self.extract_feat(inputs) + + losses = dict() + + if self.with_head: + losses.update( + self.head.loss(feats, data_samples, train_cfg=self.train_cfg)) + + return losses + + def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W) + data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + + Returns: + list[:obj:`PoseDataSample`]: The pose estimation results of the + input images. The return value is `PoseDataSample` instances with + ``pred_instances`` and ``pred_fields``(optional) field , and + ``pred_instances`` usually contains the following keys: + + - keypoints (Tensor): predicted keypoint coordinates in shape + (num_instances, K, D) where K is the keypoint number and D + is the keypoint dimension + - keypoint_scores (Tensor): predicted keypoint scores in shape + (num_instances, K) + """ + assert self.with_head, ( + 'The model must have head to perform prediction.') + + if self.test_cfg.get('flip_test', False): + _feats = self.extract_feat(inputs) + _feats_flip = self.extract_feat(inputs.flip(-1)) + feats = [_feats, _feats_flip] + else: + feats = self.extract_feat(inputs) + + preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg) + + if isinstance(preds, tuple): + batch_pred_instances, batch_pred_fields = preds + else: + batch_pred_instances = preds + batch_pred_fields = None + + results = self.add_pred_to_datasample(batch_pred_instances, + batch_pred_fields, data_samples) + + return results + + def add_pred_to_datasample(self, batch_pred_instances: InstanceList, + batch_pred_fields: Optional[PixelDataList], + batch_data_samples: SampleList) -> SampleList: + """Add predictions into data samples. + + Args: + batch_pred_instances (List[InstanceData]): The predicted instances + of the input data batch + batch_pred_fields (List[PixelData], optional): The predicted + fields (e.g. heatmaps) of the input batch + batch_data_samples (List[PoseDataSample]): The input data batch + + Returns: + List[PoseDataSample]: A list of data samples where the predictions + are stored in the ``pred_instances`` field of each data sample. + """ + assert len(batch_pred_instances) == len(batch_data_samples) + if batch_pred_fields is None: + batch_pred_fields = [] + output_keypoint_indices = self.test_cfg.get('output_keypoint_indices', + None) + + for pred_instances, pred_fields, data_sample in zip_longest( + batch_pred_instances, batch_pred_fields, batch_data_samples): + + gt_instances = data_sample.gt_instances + + # convert keypoint coordinates from input space to image space + bbox_centers = gt_instances.bbox_centers + bbox_scales = gt_instances.bbox_scales + input_size = data_sample.metainfo['input_size'] + + pred_instances.keypoints = pred_instances.keypoints / input_size \ + * bbox_scales + bbox_centers - 0.5 * bbox_scales + + if output_keypoint_indices is not None: + # select output keypoints with given indices + num_keypoints = pred_instances.keypoints.shape[1] + for key, value in pred_instances.all_items(): + if key.startswith('keypoint'): + pred_instances.set_field( + value[:, output_keypoint_indices], key) + + # add bbox information into pred_instances + pred_instances.bboxes = gt_instances.bboxes + pred_instances.bbox_scores = gt_instances.bbox_scores + + data_sample.pred_instances = pred_instances + + if pred_fields is not None: + if output_keypoint_indices is not None: + # select output heatmap channels with keypoint indices + # when the number of heatmap channel matches num_keypoints + for key, value in pred_fields.all_items(): + if value.shape[0] != num_keypoints: + continue + pred_fields.set_field(value[output_keypoint_indices], + key) + data_sample.pred_fields = pred_fields + + return batch_data_samples diff --git a/mmpose/models/utils/__init__.py b/mmpose/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22d8a89b41da0be83756f2f7a2e1a6194eb7cd5a --- /dev/null +++ b/mmpose/models/utils/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .check_and_update_config import check_and_update_config +from .ckpt_convert import pvt_convert +from .rtmcc_block import RTMCCBlock, rope +from .transformer import PatchEmbed, nchw_to_nlc, nlc_to_nchw + +__all__ = [ + 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'pvt_convert', 'RTMCCBlock', + 'rope', 'check_and_update_config' +] diff --git a/mmpose/models/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/models/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dcf32a76e1b8c138f5b12c0af80e543e9a96aba Binary files /dev/null and b/mmpose/models/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/models/utils/__pycache__/check_and_update_config.cpython-38.pyc b/mmpose/models/utils/__pycache__/check_and_update_config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95204c2631a0b8893555c556a72136404950711e Binary files /dev/null and b/mmpose/models/utils/__pycache__/check_and_update_config.cpython-38.pyc differ diff --git a/mmpose/models/utils/__pycache__/ckpt_convert.cpython-38.pyc b/mmpose/models/utils/__pycache__/ckpt_convert.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79fa68fc47c6bb63a1edcc2657c804ff9c10728f Binary files /dev/null and b/mmpose/models/utils/__pycache__/ckpt_convert.cpython-38.pyc differ diff --git a/mmpose/models/utils/__pycache__/ops.cpython-38.pyc b/mmpose/models/utils/__pycache__/ops.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46d89f44788765d578d4f5c1a5705fa742ba2f5b Binary files /dev/null and b/mmpose/models/utils/__pycache__/ops.cpython-38.pyc differ diff --git a/mmpose/models/utils/__pycache__/realnvp.cpython-38.pyc b/mmpose/models/utils/__pycache__/realnvp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fc7f095b108530daced9690f47553b4e8fb133b Binary files /dev/null and b/mmpose/models/utils/__pycache__/realnvp.cpython-38.pyc differ diff --git a/mmpose/models/utils/__pycache__/regularizations.cpython-38.pyc b/mmpose/models/utils/__pycache__/regularizations.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dca329fdab6ee0a888080a0999887f3088f3b767 Binary files /dev/null and b/mmpose/models/utils/__pycache__/regularizations.cpython-38.pyc differ diff --git a/mmpose/models/utils/__pycache__/rtmcc_block.cpython-38.pyc b/mmpose/models/utils/__pycache__/rtmcc_block.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b29af37d072c3993be24a19958ce30473ef6bb59 Binary files /dev/null and b/mmpose/models/utils/__pycache__/rtmcc_block.cpython-38.pyc differ diff --git a/mmpose/models/utils/__pycache__/transformer.cpython-38.pyc b/mmpose/models/utils/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..932e3120b45f06de03258c467d50f0b8ba64d025 Binary files /dev/null and b/mmpose/models/utils/__pycache__/transformer.cpython-38.pyc differ diff --git a/mmpose/models/utils/__pycache__/tta.cpython-38.pyc b/mmpose/models/utils/__pycache__/tta.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76f305d1dc70386039b5dbee8309ce974258e90f Binary files /dev/null and b/mmpose/models/utils/__pycache__/tta.cpython-38.pyc differ diff --git a/mmpose/models/utils/check_and_update_config.py b/mmpose/models/utils/check_and_update_config.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd1efa39b584a08055d470343549349907c1a5c --- /dev/null +++ b/mmpose/models/utils/check_and_update_config.py @@ -0,0 +1,230 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple, Union + +from mmengine.config import Config, ConfigDict +from mmengine.dist import master_only +from mmengine.logging import MMLogger + +ConfigType = Union[Config, ConfigDict] + + +def process_input_transform(input_transform: str, head: Dict, head_new: Dict, + head_deleted_dict: Dict, head_append_dict: Dict, + neck_new: Dict, input_index: Tuple[int], + align_corners: bool) -> None: + """Process the input_transform field and update head and neck + dictionaries.""" + if input_transform == 'resize_concat': + in_channels = head_new.pop('in_channels') + head_deleted_dict['in_channels'] = str(in_channels) + in_channels = sum([in_channels[i] for i in input_index]) + head_new['in_channels'] = in_channels + head_append_dict['in_channels'] = str(in_channels) + + neck_new.update( + dict( + type='FeatureMapProcessor', + concat=True, + select_index=input_index, + )) + if align_corners: + neck_new['align_corners'] = align_corners + + elif input_transform == 'select': + if input_index != (-1, ): + neck_new.update( + dict(type='FeatureMapProcessor', select_index=input_index)) + if isinstance(head['in_channels'], tuple): + in_channels = head_new.pop('in_channels') + head_deleted_dict['in_channels'] = str(in_channels) + if isinstance(input_index, int): + in_channels = in_channels[input_index] + else: + in_channels = tuple([in_channels[i] for i in input_index]) + head_new['in_channels'] = in_channels + head_append_dict['in_channels'] = str(in_channels) + if align_corners: + neck_new['align_corners'] = align_corners + + else: + raise ValueError(f'model.head get invalid value for argument ' + f'input_transform: {input_transform}') + + +def process_extra_field(extra: Dict, head_new: Dict, head_deleted_dict: Dict, + head_append_dict: Dict, neck_new: Dict) -> None: + """Process the extra field and update head and neck dictionaries.""" + head_deleted_dict['extra'] = 'dict(' + for key, value in extra.items(): + head_deleted_dict['extra'] += f'{key}={value},' + head_deleted_dict['extra'] = head_deleted_dict['extra'][:-1] + ')' + if 'final_conv_kernel' in extra: + kernel_size = extra['final_conv_kernel'] + if kernel_size > 1: + padding = kernel_size // 2 + head_new['final_layer'] = dict( + kernel_size=kernel_size, padding=padding) + head_append_dict[ + 'final_layer'] = f'dict(kernel_size={kernel_size}, ' \ + f'padding={padding})' + else: + head_new['final_layer'] = dict(kernel_size=kernel_size) + head_append_dict[ + 'final_layer'] = f'dict(kernel_size={kernel_size})' + if 'upsample' in extra: + neck_new.update( + dict( + type='FeatureMapProcessor', + scale_factor=float(extra['upsample']), + apply_relu=True, + )) + + +def process_has_final_layer(has_final_layer: bool, head_new: Dict, + head_deleted_dict: Dict, + head_append_dict: Dict) -> None: + """Process the has_final_layer field and update the head dictionary.""" + head_deleted_dict['has_final_layer'] = str(has_final_layer) + if not has_final_layer: + if 'final_layer' not in head_new: + head_new['final_layer'] = None + head_append_dict['final_layer'] = 'None' + + +def check_and_update_config(neck: Optional[ConfigType], + head: ConfigType) -> Tuple[Optional[Dict], Dict]: + """Check and update the configuration of the head and neck components. + Args: + neck (Optional[ConfigType]): Configuration for the neck component. + head (ConfigType): Configuration for the head component. + + Returns: + Tuple[Optional[Dict], Dict]: Updated configurations for the neck + and head components. + """ + head_new, neck_new = head.copy(), neck.copy() if isinstance(neck, + dict) else {} + head_deleted_dict, head_append_dict = {}, {} + + if 'input_transform' in head: + input_transform = head_new.pop('input_transform') + head_deleted_dict['input_transform'] = f'\'{input_transform}\'' + else: + input_transform = 'select' + + if 'input_index' in head: + input_index = head_new.pop('input_index') + head_deleted_dict['input_index'] = str(input_index) + else: + input_index = (-1, ) + + if 'align_corners' in head: + align_corners = head_new.pop('align_corners') + head_deleted_dict['align_corners'] = str(align_corners) + else: + align_corners = False + + process_input_transform(input_transform, head, head_new, head_deleted_dict, + head_append_dict, neck_new, input_index, + align_corners) + + if 'extra' in head: + extra = head_new.pop('extra') + process_extra_field(extra, head_new, head_deleted_dict, + head_append_dict, neck_new) + + if 'has_final_layer' in head: + has_final_layer = head_new.pop('has_final_layer') + process_has_final_layer(has_final_layer, head_new, head_deleted_dict, + head_append_dict) + + display_modifications(head_deleted_dict, head_append_dict, neck_new) + + neck_new = neck_new if len(neck_new) else None + return neck_new, head_new + + +@master_only +def display_modifications(head_deleted_dict: Dict, head_append_dict: Dict, + neck: Dict) -> None: + """Display the modifications made to the head and neck configurations. + + Args: + head_deleted_dict (Dict): Dictionary of deleted fields in the head. + head_append_dict (Dict): Dictionary of appended fields in the head. + neck (Dict): Updated neck configuration. + """ + if len(head_deleted_dict) + len(head_append_dict) == 0: + return + + old_model_info, new_model_info = build_model_info(head_deleted_dict, + head_append_dict, neck) + + total_info = '\nThe config you are using is outdated. '\ + 'The following section of the config:\n```\n' + total_info += old_model_info + total_info += '```\nshould be updated to\n```\n' + total_info += new_model_info + total_info += '```\nFor more information, please refer to '\ + 'https://mmpose.readthedocs.io/en/latest/' \ + 'guide_to_framework.html#step3-model' + + logger: MMLogger = MMLogger.get_current_instance() + logger.warning(total_info) + + +def build_model_info(head_deleted_dict: Dict, head_append_dict: Dict, + neck: Dict) -> Tuple[str, str]: + """Build the old and new model information strings. + Args: + head_deleted_dict (Dict): Dictionary of deleted fields in the head. + head_append_dict (Dict): Dictionary of appended fields in the head. + neck (Dict): Updated neck configuration. + + Returns: + Tuple[str, str]: Old and new model information strings. + """ + old_head_info = build_head_info(head_deleted_dict) + new_head_info = build_head_info(head_append_dict) + neck_info = build_neck_info(neck) + + old_model_info = 'model=dict(\n' + ' ' * 4 + '...,\n' + old_head_info + new_model_info = 'model=dict(\n' + ' ' * 4 + '...,\n' \ + + neck_info + new_head_info + + return old_model_info, new_model_info + + +def build_head_info(head_dict: Dict) -> str: + """Build the head information string. + + Args: + head_dict (Dict): Dictionary of fields in the head configuration. + Returns: + str: Head information string. + """ + head_info = ' ' * 4 + 'head=dict(\n' + for key, value in head_dict.items(): + head_info += ' ' * 8 + f'{key}={value},\n' + head_info += ' ' * 8 + '...),\n' + return head_info + + +def build_neck_info(neck: Dict) -> str: + """Build the neck information string. + Args: + neck (Dict): Updated neck configuration. + + Returns: + str: Neck information string. + """ + if len(neck) > 0: + neck = neck.copy() + neck_info = ' ' * 4 + 'neck=dict(\n' + ' ' * 8 + \ + f'type=\'{neck.pop("type")}\',\n' + for key, value in neck.items(): + neck_info += ' ' * 8 + f'{key}={str(value)},\n' + neck_info += ' ' * 4 + '),\n' + else: + neck_info = '' + return neck_info diff --git a/mmpose/models/utils/ckpt_convert.py b/mmpose/models/utils/ckpt_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..05f5cdb4a3cdf32ac2b6b7a8888c5a772e582f14 --- /dev/null +++ b/mmpose/models/utils/ckpt_convert.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# This script consists of several convert functions which +# can modify the weights of model in original repo to be +# pre-trained weights. + +from collections import OrderedDict + +import torch + + +def pvt_convert(ckpt): + new_ckpt = OrderedDict() + # Process the concat between q linear weights and kv linear weights + use_abs_pos_embed = False + use_conv_ffn = False + for k in ckpt.keys(): + if k.startswith('pos_embed'): + use_abs_pos_embed = True + if k.find('dwconv') >= 0: + use_conv_ffn = True + for k, v in ckpt.items(): + if k.startswith('head'): + continue + if k.startswith('norm.'): + continue + if k.startswith('cls_token'): + continue + if k.startswith('pos_embed'): + stage_i = int(k.replace('pos_embed', '')) + new_k = k.replace(f'pos_embed{stage_i}', + f'layers.{stage_i - 1}.1.0.pos_embed') + if stage_i == 4 and v.size(1) == 50: # 1 (cls token) + 7 * 7 + new_v = v[:, 1:, :] # remove cls token + else: + new_v = v + elif k.startswith('patch_embed'): + stage_i = int(k.split('.')[0].replace('patch_embed', '')) + new_k = k.replace(f'patch_embed{stage_i}', + f'layers.{stage_i - 1}.0') + new_v = v + if 'proj.' in new_k: + new_k = new_k.replace('proj.', 'projection.') + elif k.startswith('block'): + stage_i = int(k.split('.')[0].replace('block', '')) + layer_i = int(k.split('.')[1]) + new_layer_i = layer_i + use_abs_pos_embed + new_k = k.replace(f'block{stage_i}.{layer_i}', + f'layers.{stage_i - 1}.1.{new_layer_i}') + new_v = v + if 'attn.q.' in new_k: + sub_item_k = k.replace('q.', 'kv.') + new_k = new_k.replace('q.', 'attn.in_proj_') + new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) + elif 'attn.kv.' in new_k: + continue + elif 'attn.proj.' in new_k: + new_k = new_k.replace('proj.', 'attn.out_proj.') + elif 'attn.sr.' in new_k: + new_k = new_k.replace('sr.', 'sr.') + elif 'mlp.' in new_k: + string = f'{new_k}-' + new_k = new_k.replace('mlp.', 'ffn.layers.') + if 'fc1.weight' in new_k or 'fc2.weight' in new_k: + new_v = v.reshape((*v.shape, 1, 1)) + new_k = new_k.replace('fc1.', '0.') + new_k = new_k.replace('dwconv.dwconv.', '1.') + if use_conv_ffn: + new_k = new_k.replace('fc2.', '4.') + else: + new_k = new_k.replace('fc2.', '3.') + string += f'{new_k} {v.shape}-{new_v.shape}' + elif k.startswith('norm'): + stage_i = int(k[4]) + new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2') + new_v = v + else: + new_k = k + new_v = v + new_ckpt[new_k] = new_v + + return new_ckpt diff --git a/mmpose/models/utils/geometry.py b/mmpose/models/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..0ceadaec30cd2c9bb3fbada132e1ea674f2e8754 --- /dev/null +++ b/mmpose/models/utils/geometry.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn import functional as F + + +def rot6d_to_rotmat(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + + Based on Zhou et al., "On the Continuity of Rotation + Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1, 3, 2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def batch_rodrigues(theta): + """Convert axis-angle representation to rotation matrix. + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion + -- size = [B, 3, 3] + """ + l2norm = torch.norm(theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(l2norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim=1) + return quat_to_rotmat(quat) + + +def quat_to_rotmat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion + -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1],\ + norm_quat[:, 2], norm_quat[:, 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(B, 3, 3) + return rotMat diff --git a/mmpose/models/utils/ops.py b/mmpose/models/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c94352647178e53e618e8fc7cba36fa7a9c0ad2 --- /dev/null +++ b/mmpose/models/utils/ops.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch.nn import functional as F + + +def resize(input: torch.Tensor, + size: Optional[Union[Tuple[int, int], torch.Size]] = None, + scale_factor: Optional[float] = None, + mode: str = 'nearest', + align_corners: Optional[bool] = None, + warning: bool = True) -> torch.Tensor: + """Resize a given input tensor using specified size or scale_factor. + + Args: + input (torch.Tensor): The input tensor to be resized. + size (Optional[Union[Tuple[int, int], torch.Size]]): The desired + output size. Defaults to None. + scale_factor (Optional[float]): The scaling factor for resizing. + Defaults to None. + mode (str): The interpolation mode. Defaults to 'nearest'. + align_corners (Optional[bool]): Determines whether to align the + corners when using certain interpolation modes. Defaults to None. + warning (bool): Whether to display a warning when the input and + output sizes are not ideal for alignment. Defaults to True. + + Returns: + torch.Tensor: The resized tensor. + """ + # Check if a warning should be displayed regarding input and output sizes + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would be more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + + # Convert torch.Size to tuple if necessary + if isinstance(size, torch.Size): + size = tuple(int(x) for x in size) + + # Perform the resizing operation + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/mmpose/models/utils/realnvp.py b/mmpose/models/utils/realnvp.py new file mode 100644 index 0000000000000000000000000000000000000000..911953e8f9d1056d44a2d3538d750e89b9bd6a7a --- /dev/null +++ b/mmpose/models/utils/realnvp.py @@ -0,0 +1,76 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from torch import distributions + + +class RealNVP(nn.Module): + """RealNVP: a flow-based generative model + + `Density estimation using Real NVP + arXiv: `_. + + Code is modified from `the official implementation of RLE + `_. + + See also `real-nvp-pytorch + `_. + """ + + @staticmethod + def get_scale_net(): + """Get the scale model in a single invertable mapping.""" + return nn.Sequential( + nn.Linear(2, 64), nn.LeakyReLU(), nn.Linear(64, 64), + nn.LeakyReLU(), nn.Linear(64, 2), nn.Tanh()) + + @staticmethod + def get_trans_net(): + """Get the translation model in a single invertable mapping.""" + return nn.Sequential( + nn.Linear(2, 64), nn.LeakyReLU(), nn.Linear(64, 64), + nn.LeakyReLU(), nn.Linear(64, 2)) + + @property + def prior(self): + """The prior distribution.""" + return distributions.MultivariateNormal(self.loc, self.cov) + + def __init__(self): + super(RealNVP, self).__init__() + + self.register_buffer('loc', torch.zeros(2)) + self.register_buffer('cov', torch.eye(2)) + self.register_buffer( + 'mask', torch.tensor([[0, 1], [1, 0]] * 3, dtype=torch.float32)) + + self.s = torch.nn.ModuleList( + [self.get_scale_net() for _ in range(len(self.mask))]) + self.t = torch.nn.ModuleList( + [self.get_trans_net() for _ in range(len(self.mask))]) + self.init_weights() + + def init_weights(self): + """Initialization model weights.""" + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight, gain=0.01) + + def backward_p(self, x): + """Apply mapping form the data space to the latent space and calculate + the log determinant of the Jacobian matrix.""" + + log_det_jacob, z = x.new_zeros(x.shape[0]), x + for i in reversed(range(len(self.t))): + z_ = self.mask[i] * z + s = self.s[i](z_) * (1 - self.mask[i]) # torch.exp(s): betas + t = self.t[i](z_) * (1 - self.mask[i]) # gammas + z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_ + log_det_jacob -= s.sum(dim=1) + return z, log_det_jacob + + def log_prob(self, x): + """Calculate the log probability of given sample in data space.""" + + z, log_det = self.backward_p(x) + return self.prior.log_prob(z) + log_det diff --git a/mmpose/models/utils/regularizations.py b/mmpose/models/utils/regularizations.py new file mode 100644 index 0000000000000000000000000000000000000000..d8c7449038066016f6efb60e126111ace962fe98 --- /dev/null +++ b/mmpose/models/utils/regularizations.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod, abstractproperty + +import torch + + +class PytorchModuleHook(metaclass=ABCMeta): + """Base class for PyTorch module hook registers. + + An instance of a subclass of PytorchModuleHook can be used to + register hook to a pytorch module using the `register` method like: + hook_register.register(module) + + Subclasses should add/overwrite the following methods: + - __init__ + - hook + - hook_type + """ + + @abstractmethod + def hook(self, *args, **kwargs): + """Hook function.""" + + @abstractproperty + def hook_type(self) -> str: + """Hook type Subclasses should overwrite this function to return a + string value in. + + {`forward`, `forward_pre`, `backward`} + """ + + def register(self, module): + """Register the hook function to the module. + + Args: + module (pytorch module): the module to register the hook. + + Returns: + handle (torch.utils.hooks.RemovableHandle): a handle to remove + the hook by calling handle.remove() + """ + assert isinstance(module, torch.nn.Module) + + if self.hook_type == 'forward': + h = module.register_forward_hook(self.hook) + elif self.hook_type == 'forward_pre': + h = module.register_forward_pre_hook(self.hook) + elif self.hook_type == 'backward': + h = module.register_backward_hook(self.hook) + else: + raise ValueError(f'Invalid hook type {self.hook}') + + return h + + +class WeightNormClipHook(PytorchModuleHook): + """Apply weight norm clip regularization. + + The module's parameter will be clip to a given maximum norm before each + forward pass. + + Args: + max_norm (float): The maximum norm of the parameter. + module_param_names (str|list): The parameter name (or name list) to + apply weight norm clip. + """ + + def __init__(self, max_norm=1.0, module_param_names='weight'): + self.module_param_names = module_param_names if isinstance( + module_param_names, list) else [module_param_names] + self.max_norm = max_norm + + @property + def hook_type(self): + return 'forward_pre' + + def hook(self, module, _input): + for name in self.module_param_names: + assert name in module._parameters, f'{name} is not a parameter' \ + f' of the module {type(module)}' + param = module._parameters[name] + + with torch.no_grad(): + m = param.norm().item() + if m > self.max_norm: + param.mul_(self.max_norm / (m + 1e-6)) diff --git a/mmpose/models/utils/rtmcc_block.py b/mmpose/models/utils/rtmcc_block.py new file mode 100644 index 0000000000000000000000000000000000000000..0e317376b2e969d01c31001748a22d6cef042f0f --- /dev/null +++ b/mmpose/models/utils/rtmcc_block.py @@ -0,0 +1,299 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks import DropPath +from mmengine.utils import digit_version +from mmengine.utils.dl_utils import TORCH_VERSION + + +def rope(x, dim): + """Applies Rotary Position Embedding to input tensor. + + Args: + x (torch.Tensor): Input tensor. + dim (int | list[int]): The spatial dimension(s) to apply + rotary position embedding. + + Returns: + torch.Tensor: The tensor after applying rotary position + embedding. + + Reference: + `RoFormer: Enhanced Transformer with Rotary + Position Embedding `_ + """ + shape = x.shape + if isinstance(dim, int): + dim = [dim] + + spatial_shape = [shape[i] for i in dim] + total_len = 1 + for i in spatial_shape: + total_len *= i + + position = torch.reshape( + torch.arange(total_len, dtype=torch.int, device=x.device), + spatial_shape) + + for i in range(dim[-1] + 1, len(shape) - 1, 1): + position = torch.unsqueeze(position, dim=-1) + + half_size = shape[-1] // 2 + freq_seq = -torch.arange( + half_size, dtype=torch.int, device=x.device) / float(half_size) + inv_freq = 10000**-freq_seq + + sinusoid = position[..., None] * inv_freq[None, None, :] + + sin = torch.sin(sinusoid) + cos = torch.cos(sinusoid) + x1, x2 = torch.chunk(x, 2, dim=-1) + + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + +class Scale(nn.Module): + """Scale vector by element multiplications. + + Args: + dim (int): The dimension of the scale vector. + init_value (float, optional): The initial value of the scale vector. + Defaults to 1.0. + trainable (bool, optional): Whether the scale vector is trainable. + Defaults to True. + """ + + def __init__(self, dim, init_value=1., trainable=True): + super().__init__() + self.scale = nn.Parameter( + init_value * torch.ones(dim), requires_grad=trainable) + + def forward(self, x): + """Forward function.""" + + return x * self.scale + + +class ScaleNorm(nn.Module): + """Scale Norm. + + Args: + dim (int): The dimension of the scale vector. + eps (float, optional): The minimum value in clamp. Defaults to 1e-5. + + Reference: + `Transformers without Tears: Improving the Normalization + of Self-Attention `_ + """ + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + """Forward function. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: The tensor after applying scale norm. + """ + + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RTMCCBlock(nn.Module): + """Gated Attention Unit (GAU) in RTMBlock. + + Args: + num_token (int): The number of tokens. + in_token_dims (int): The input token dimension. + out_token_dims (int): The output token dimension. + expansion_factor (int, optional): The expansion factor of the + intermediate token dimension. Defaults to 2. + s (int, optional): The self-attention feature dimension. + Defaults to 128. + eps (float, optional): The minimum value in clamp. Defaults to 1e-5. + dropout_rate (float, optional): The dropout rate. Defaults to 0.0. + drop_path (float, optional): The drop path rate. Defaults to 0.0. + attn_type (str, optional): Type of attention which should be one of + the following options: + + - 'self-attn': Self-attention. + - 'cross-attn': Cross-attention. + + Defaults to 'self-attn'. + act_fn (str, optional): The activation function which should be one + of the following options: + + - 'ReLU': ReLU activation. + - 'SiLU': SiLU activation. + + Defaults to 'SiLU'. + bias (bool, optional): Whether to use bias in linear layers. + Defaults to False. + use_rel_bias (bool, optional): Whether to use relative bias. + Defaults to True. + pos_enc (bool, optional): Whether to use rotary position + embedding. Defaults to False. + + Reference: + `Transformer Quality in Linear Time + `_ + """ + + def __init__(self, + num_token, + in_token_dims, + out_token_dims, + expansion_factor=2, + s=128, + eps=1e-5, + dropout_rate=0., + drop_path=0., + attn_type='self-attn', + act_fn='SiLU', + bias=False, + use_rel_bias=True, + pos_enc=False): + + super(RTMCCBlock, self).__init__() + self.s = s + self.num_token = num_token + self.use_rel_bias = use_rel_bias + self.attn_type = attn_type + self.pos_enc = pos_enc + self.drop_path = DropPath(drop_path) \ + if drop_path > 0. else nn.Identity() + + self.e = int(in_token_dims * expansion_factor) + if use_rel_bias: + if attn_type == 'self-attn': + self.w = nn.Parameter( + torch.rand([2 * num_token - 1], dtype=torch.float)) + else: + self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float)) + self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float)) + self.o = nn.Linear(self.e, out_token_dims, bias=bias) + + if attn_type == 'self-attn': + self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias) + self.gamma = nn.Parameter(torch.rand((2, self.s))) + self.beta = nn.Parameter(torch.rand((2, self.s))) + else: + self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias) + self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias) + self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias) + nn.init.xavier_uniform_(self.k_fc.weight) + nn.init.xavier_uniform_(self.v_fc.weight) + + self.ln = ScaleNorm(in_token_dims, eps=eps) + + nn.init.xavier_uniform_(self.uv.weight) + + if act_fn == 'SiLU': + assert digit_version(TORCH_VERSION) >= digit_version('1.7.0'), \ + 'SiLU activation requires PyTorch version >= 1.7' + + self.act_fn = nn.SiLU(True) + else: + self.act_fn = nn.ReLU(True) + + if in_token_dims == out_token_dims: + self.shortcut = True + self.res_scale = Scale(in_token_dims) + else: + self.shortcut = False + + self.sqrt_s = math.sqrt(s) + + self.dropout_rate = dropout_rate + + if dropout_rate > 0.: + self.dropout = nn.Dropout(dropout_rate) + + def rel_pos_bias(self, seq_len, k_len=None): + """Add relative position bias.""" + + if self.attn_type == 'self-attn': + t = F.pad(self.w[:2 * seq_len - 1], [0, seq_len]).repeat(seq_len) + t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2) + r = (2 * seq_len - 1) // 2 + t = t[..., r:-r] + else: + a = rope(self.a.repeat(seq_len, 1), dim=0) + b = rope(self.b.repeat(k_len, 1), dim=0) + t = torch.bmm(a, b.permute(0, 2, 1)) + return t + + def _forward(self, inputs): + """GAU Forward function.""" + + if self.attn_type == 'self-attn': + x = inputs + else: + x, k, v = inputs + + x = self.ln(x) + + uv = self.uv(x) + + if self.attn_type == 'self-attn': + u, v, base = torch.split( + self.act_fn(uv), [self.e, self.e, self.s], dim=-1) + + base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta + + if self.pos_enc: + base = rope(base, dim=1) + + q, k = torch.unbind(base, dim=-2) + + else: + u, q = torch.split(self.act_fn(uv), [self.e, self.s], dim=-1) + + k = self.k_fc(k) + v = self.v_fc(v) + + if self.pos_enc: + q = rope(q, 1) + k = rope(k, 1) + + qk = torch.bmm(q, k.permute(0, 2, 1)) + + if self.use_rel_bias: + if self.attn_type == 'self-attn': + bias = self.rel_pos_bias(q.size(1)) + else: + bias = self.rel_pos_bias(q.size(1), k.size(1)) + qk += bias[:, :q.size(1), :k.size(1)] + + kernel = torch.square(F.relu(qk / self.sqrt_s)) + + if self.dropout_rate > 0.: + kernel = self.dropout(kernel) + + x = u * torch.bmm(kernel, v) + x = self.o(x) + + return x + + def forward(self, x): + """Forward function.""" + + if self.shortcut: + if self.attn_type == 'cross-attn': + res_shortcut = x[0] + else: + res_shortcut = x + main_branch = self.drop_path(self._forward(x)) + return self.res_scale(res_shortcut) + main_branch + else: + return self.drop_path(self._forward(x)) diff --git a/mmpose/models/utils/transformer.py b/mmpose/models/utils/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..103b9e9970a7b96b1d5ef288d9fbf5b787838a92 --- /dev/null +++ b/mmpose/models/utils/transformer.py @@ -0,0 +1,369 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Sequence + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils import to_2tuple + + +def nlc_to_nchw(x, hw_shape): + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before conversion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after conversion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len does not match H, W' + return x.transpose(1, 2).reshape(B, C, H, W).contiguous() + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before conversion. + + Returns: + Tensor: The output tensor of shape [N, L, C] after conversion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super(AdaptivePadding, self).__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + """Get horizontal and vertical padding shapes.""" + + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + """Forward function.""" + + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d. + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__( + self, + in_channels=3, + embed_dims=768, + conv_type='Conv2d', + kernel_size=16, + stride=16, + padding='corner', + dilation=1, + bias=True, + norm_cfg=None, + input_size=None, + init_cfg=None, + ): + super(PatchEmbed, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + if input_size: + input_size = to_2tuple(input_size) + # `init_out_size` would be used outside to + # calculate the num_patches + # when `use_abs_pos_embed` outside + self.init_input_size = input_size + if self.adap_padding: + pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) + + # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + h_out = (input_size[0] + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (input_size[1] + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adap_padding: + x = self.adap_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + to gets fully covered by filter and stride you specified.. + Default: True. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size diff --git a/mmpose/models/utils/tta.py b/mmpose/models/utils/tta.py new file mode 100644 index 0000000000000000000000000000000000000000..0add48a422a676e131fdc2cc31a9d7bfeadc382a --- /dev/null +++ b/mmpose/models/utils/tta.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def flip_heatmaps(heatmaps: Tensor, + flip_indices: Optional[List[int]] = None, + flip_mode: str = 'heatmap', + shift_heatmap: bool = True): + """Flip heatmaps for test-time augmentation. + + Args: + heatmaps (Tensor): The heatmaps to flip. Should be a tensor in shape + [B, C, H, W] + flip_indices (List[int]): The indices of each keypoint's symmetric + keypoint. Defaults to ``None`` + flip_mode (str): Specify the flipping mode. Options are: + + - ``'heatmap'``: horizontally flip the heatmaps and swap heatmaps + of symmetric keypoints according to ``flip_indices`` + - ``'udp_combined'``: similar to ``'heatmap'`` mode but further + flip the x_offset values + - ``'offset'``: horizontally flip the offset fields and swap + heatmaps of symmetric keypoints according to + ``flip_indices``. x_offset values are also reversed + shift_heatmap (bool): Shift the flipped heatmaps to align with the + original heatmaps and improve accuracy. Defaults to ``True`` + + Returns: + Tensor: flipped heatmaps in shape [B, C, H, W] + """ + + if flip_mode == 'heatmap': + heatmaps = heatmaps.flip(-1) + if flip_indices is not None: + assert len(flip_indices) == heatmaps.shape[1] + heatmaps = heatmaps[:, flip_indices] + elif flip_mode == 'udp_combined': + B, C, H, W = heatmaps.shape + heatmaps = heatmaps.view(B, C // 3, 3, H, W) + heatmaps = heatmaps.flip(-1) + if flip_indices is not None: + assert len(flip_indices) == C // 3 + heatmaps = heatmaps[:, flip_indices] + heatmaps[:, :, 1] = -heatmaps[:, :, 1] + heatmaps = heatmaps.view(B, C, H, W) + + elif flip_mode == 'offset': + B, C, H, W = heatmaps.shape + heatmaps = heatmaps.view(B, C // 2, -1, H, W) + heatmaps = heatmaps.flip(-1) + if flip_indices is not None: + assert len(flip_indices) == C // 2 + heatmaps = heatmaps[:, flip_indices] + heatmaps[:, :, 0] = -heatmaps[:, :, 0] + heatmaps = heatmaps.view(B, C, H, W) + + else: + raise ValueError(f'Invalid flip_mode value "{flip_mode}"') + + if shift_heatmap: + # clone data to avoid unexpected in-place operation when using CPU + heatmaps[..., 1:] = heatmaps[..., :-1].clone() + + return heatmaps + + +def flip_vectors(x_labels: Tensor, y_labels: Tensor, flip_indices: List[int]): + """Flip instance-level labels in specific axis for test-time augmentation. + + Args: + x_labels (Tensor): The vector labels in x-axis to flip. Should be + a tensor in shape [B, C, Wx] + y_labels (Tensor): The vector labels in y-axis to flip. Should be + a tensor in shape [B, C, Wy] + flip_indices (List[int]): The indices of each keypoint's symmetric + keypoint + """ + assert x_labels.ndim == 3 and y_labels.ndim == 3 + assert len(flip_indices) == x_labels.shape[1] and len( + flip_indices) == y_labels.shape[1] + x_labels = x_labels[:, flip_indices].flip(-1) + y_labels = y_labels[:, flip_indices] + + return x_labels, y_labels + + +def flip_coordinates(coords: Tensor, flip_indices: List[int], + shift_coords: bool, input_size: Tuple[int, int]): + """Flip normalized coordinates for test-time augmentation. + + Args: + coords (Tensor): The coordinates to flip. Should be a tensor in shape + [B, K, D] + flip_indices (List[int]): The indices of each keypoint's symmetric + keypoint + shift_coords (bool): Shift the flipped coordinates to align with the + original coordinates and improve accuracy. Defaults to ``True`` + input_size (Tuple[int, int]): The size of input image in [w, h] + """ + assert coords.ndim == 3 + assert len(flip_indices) == coords.shape[1] + + coords[:, :, 0] = 1.0 - coords[:, :, 0] + + if shift_coords: + img_width = input_size[0] + coords[:, :, 0] -= 1.0 / img_width + + coords = coords[:, flip_indices] + return coords + + +def aggregate_heatmaps(heatmaps: List[Tensor], + size: Optional[Tuple[int, int]], + align_corners: bool = False, + mode: str = 'average'): + """Aggregate multiple heatmaps. + + Args: + heatmaps (List[Tensor]): Multiple heatmaps to aggregate. Each should + be in shape (B, C, H, W) + size (Tuple[int, int], optional): The target size in (w, h). All + heatmaps will be resized to the target size. If not given, the + first heatmap tensor's width and height will be used as the target + size. Defaults to ``None`` + align_corners (bool): Whether align corners when resizing heatmaps. + Defaults to ``False`` + mode (str): Aggregation mode in one of the following: + + - ``'average'``: Get average of heatmaps. All heatmaps mush have + the same channel number + - ``'concat'``: Concate the heatmaps at the channel dim + """ + + if mode not in {'average', 'concat'}: + raise ValueError(f'Invalid aggregation mode `{mode}`') + + if size is None: + h, w = heatmaps[0].shape[2:4] + else: + w, h = size + + for i, _heatmaps in enumerate(heatmaps): + assert _heatmaps.ndim == 4 + if mode == 'average': + assert _heatmaps.shape[:2] == heatmaps[0].shape[:2] + else: + assert _heatmaps.shape[0] == heatmaps[0].shape[0] + + if _heatmaps.shape[2:4] != (h, w): + heatmaps[i] = F.interpolate( + _heatmaps, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + + if mode == 'average': + output = sum(heatmaps).div(len(heatmaps)) + elif mode == 'concat': + output = torch.cat(heatmaps, dim=1) + else: + raise ValueError() + + return output diff --git a/mmpose/registry.py b/mmpose/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b8d17c4c33dc78b457608d1b3951401cf55d52 --- /dev/null +++ b/mmpose/registry.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMPose provides following registry nodes to support using modules across +projects. + +Each node is a child of the root registry in MMEngine. +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. +""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +# Registries For Runner and the related +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry('loop', parent=MMENGINE_LOOPS) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry( + 'hook', parent=MMENGINE_HOOKS, locations=['mmpose.engine.hooks']) + +# Registries For Data and the related +# manage data-related modules +DATASETS = Registry( + 'dataset', parent=MMENGINE_DATASETS, locations=['mmpose.datasets']) +DATA_SAMPLERS = Registry( + 'data sampler', + parent=MMENGINE_DATA_SAMPLERS, + locations=['mmpose.datasets.samplers']) +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmpose.datasets.transforms']) + +# manage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmpose.models']) +# manage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmpose.models']) +# manage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmpose.models']) +# manage all kinds of batch augmentations like Mixup and CutMix. +BATCH_AUGMENTS = Registry('batch augment', locations=['mmpose.models']) + +# Registries For Optimizer and the related +# manage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry( + 'optimizer', parent=MMENGINE_OPTIMIZERS, locations=['mmpose.engine']) +# manage optimizer wrapper +OPTIM_WRAPPERS = Registry( + 'optimizer_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmpose.engine']) +# manage constructors that customize the optimization hyperparameters. +OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer wrapper constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmpose.engine.optim_wrappers']) +# manage all kinds of parameter schedulers like `MultiStepLR` +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmpose.engine']) + +# manage all kinds of metrics +METRICS = Registry( + 'metric', parent=MMENGINE_METRICS, locations=['mmpose.evaluation.metrics']) +# manage all kinds of evaluators +EVALUATORS = Registry( + 'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmpose.evaluation']) + +# manage task-specific modules like anchor generators and box coders +TASK_UTILS = Registry( + 'task util', parent=MMENGINE_TASK_UTILS, locations=['mmpose.models']) + +# Registries For Visualizer and the related +# manage visualizer +VISUALIZERS = Registry( + 'visualizer', + parent=MMENGINE_VISUALIZERS, + locations=['mmpose.visualization']) +# manage visualizer backend +VISBACKENDS = Registry( + 'vis_backend', + parent=MMENGINE_VISBACKENDS, + locations=['mmpose.visualization']) + +# manage all kinds log processors +LOG_PROCESSORS = Registry( + 'log processor', + parent=MMENGINE_LOG_PROCESSORS, + locations=['mmpose.visualization']) + +# manager keypoint encoder/decoder +KEYPOINT_CODECS = Registry('KEYPOINT_CODECS', locations=['mmpose.codecs']) + +# manage inferencer +INFERENCERS = Registry( + 'inferencer', + parent=MMENGINE_INFERENCERS, + locations=['mmpose.apis.inferencers']) diff --git a/mmpose/structures/__init__.py b/mmpose/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4384af1cd0f7cfac9a5f8d26e34caf4c78b9fc9 --- /dev/null +++ b/mmpose/structures/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bbox import (bbox_cs2xywh, bbox_cs2xyxy, bbox_xywh2cs, bbox_xywh2xyxy, + bbox_xyxy2cs, bbox_xyxy2xywh, flip_bbox, + get_udp_warp_matrix, get_warp_matrix) +from .keypoint import flip_keypoints +from .multilevel_pixel_data import MultilevelPixelData +from .pose_data_sample import PoseDataSample +from .utils import merge_data_samples, revert_heatmap, split_instances + +__all__ = [ + 'PoseDataSample', 'MultilevelPixelData', 'bbox_cs2xywh', 'bbox_cs2xyxy', + 'bbox_xywh2cs', 'bbox_xywh2xyxy', 'bbox_xyxy2cs', 'bbox_xyxy2xywh', + 'flip_bbox', 'get_udp_warp_matrix', 'get_warp_matrix', 'flip_keypoints', + 'merge_data_samples', 'revert_heatmap', 'split_instances' +] diff --git a/mmpose/structures/__pycache__/__init__.cpython-38.pyc b/mmpose/structures/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dd8cb302514cdb3d993770724b9ca27b4861308 Binary files /dev/null and b/mmpose/structures/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/structures/__pycache__/multilevel_pixel_data.cpython-38.pyc b/mmpose/structures/__pycache__/multilevel_pixel_data.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9572277296bf43834f212fd83c9fd46effabf2a Binary files /dev/null and b/mmpose/structures/__pycache__/multilevel_pixel_data.cpython-38.pyc differ diff --git a/mmpose/structures/__pycache__/pose_data_sample.cpython-38.pyc b/mmpose/structures/__pycache__/pose_data_sample.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97b16e8532e78307a849c953394e25e19e7c5642 Binary files /dev/null and b/mmpose/structures/__pycache__/pose_data_sample.cpython-38.pyc differ diff --git a/mmpose/structures/__pycache__/utils.cpython-38.pyc b/mmpose/structures/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ef6db350956f3f3a2e7486126e480dc2e482ffe Binary files /dev/null and b/mmpose/structures/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpose/structures/bbox/__init__.py b/mmpose/structures/bbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e723918ca4b985b9e5ff3aa36b2a5c3bd2d700 --- /dev/null +++ b/mmpose/structures/bbox/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .transforms import (bbox_cs2xywh, bbox_cs2xyxy, bbox_xywh2cs, + bbox_xywh2xyxy, bbox_xyxy2cs, bbox_xyxy2xywh, + flip_bbox, get_udp_warp_matrix, get_warp_matrix) + +__all__ = [ + 'bbox_cs2xywh', 'bbox_cs2xyxy', 'bbox_xywh2cs', 'bbox_xywh2xyxy', + 'bbox_xyxy2cs', 'bbox_xyxy2xywh', 'flip_bbox', 'get_udp_warp_matrix', + 'get_warp_matrix' +] diff --git a/mmpose/structures/bbox/__pycache__/__init__.cpython-38.pyc b/mmpose/structures/bbox/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..388945bffedb893c5e22ed4fbb4d8d0fdc842aaa Binary files /dev/null and b/mmpose/structures/bbox/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/structures/bbox/__pycache__/transforms.cpython-38.pyc b/mmpose/structures/bbox/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89924501b431dc10506e465976bba57c8d74b061 Binary files /dev/null and b/mmpose/structures/bbox/__pycache__/transforms.cpython-38.pyc differ diff --git a/mmpose/structures/bbox/transforms.py b/mmpose/structures/bbox/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..027ac0717bfa3fbafa9980b523b38496f234e474 --- /dev/null +++ b/mmpose/structures/bbox/transforms.py @@ -0,0 +1,361 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Tuple + +import cv2 +import numpy as np + + +def bbox_xyxy2xywh(bbox_xyxy: np.ndarray) -> np.ndarray: + """Transform the bbox format from x1y1x2y2 to xywh. + + Args: + bbox_xyxy (np.ndarray): Bounding boxes (with scores), shaped (n, 4) or + (n, 5). (left, top, right, bottom, [score]) + + Returns: + np.ndarray: Bounding boxes (with scores), + shaped (n, 4) or (n, 5). (left, top, width, height, [score]) + """ + bbox_xywh = bbox_xyxy.copy() + bbox_xywh[:, 2] = bbox_xywh[:, 2] - bbox_xywh[:, 0] + bbox_xywh[:, 3] = bbox_xywh[:, 3] - bbox_xywh[:, 1] + + return bbox_xywh + + +def bbox_xywh2xyxy(bbox_xywh: np.ndarray) -> np.ndarray: + """Transform the bbox format from xywh to x1y1x2y2. + + Args: + bbox_xywh (ndarray): Bounding boxes (with scores), + shaped (n, 4) or (n, 5). (left, top, width, height, [score]) + Returns: + np.ndarray: Bounding boxes (with scores), shaped (n, 4) or + (n, 5). (left, top, right, bottom, [score]) + """ + bbox_xyxy = bbox_xywh.copy() + bbox_xyxy[:, 2] = bbox_xyxy[:, 2] + bbox_xyxy[:, 0] + bbox_xyxy[:, 3] = bbox_xyxy[:, 3] + bbox_xyxy[:, 1] + + return bbox_xyxy + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def bbox_xywh2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (x, y, h, w) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + x, y, w, h = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x + w * 0.5, y + h * 0.5]) + scale = np.hstack([w, h]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def bbox_cs2xyxy(center: np.ndarray, + scale: np.ndarray, + padding: float = 1.) -> np.ndarray: + """Transform the bbox format from (center, scale) to (x,y,w,h). + + Args: + center (ndarray): BBox center (x, y) in shape (2,) or (n, 2) + scale (ndarray): BBox scale (w, h) in shape (2,) or (n, 2) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + ndarray[float32]: BBox (x, y, w, h) in shape (4, ) or (n, 4) + """ + + dim = center.ndim + assert scale.ndim == dim + + if dim == 1: + center = center[None, :] + scale = scale[None, :] + + wh = scale / padding + xy = center - 0.5 * wh + bbox = np.hstack((xy, xy + wh)) + + if dim == 1: + bbox = bbox[0] + + return bbox + + +def bbox_cs2xywh(center: np.ndarray, + scale: np.ndarray, + padding: float = 1.) -> np.ndarray: + """Transform the bbox format from (center, scale) to (x,y,w,h). + + Args: + center (ndarray): BBox center (x, y) in shape (2,) or (n, 2) + scale (ndarray): BBox scale (w, h) in shape (2,) or (n, 2) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + ndarray[float32]: BBox (x, y, w, h) in shape (4, ) or (n, 4) + """ + + dim = center.ndim + assert scale.ndim == dim + + if dim == 1: + center = center[None, :] + scale = scale[None, :] + + wh = scale / padding + xy = center - 0.5 * wh + bbox = np.hstack((xy, wh)) + + if dim == 1: + bbox = bbox[0] + + return bbox + + +def flip_bbox(bbox: np.ndarray, + image_size: Tuple[int, int], + bbox_format: str = 'xywh', + direction: str = 'horizontal') -> np.ndarray: + """Flip the bbox in the given direction. + + Args: + bbox (np.ndarray): The bounding boxes. The shape should be (..., 4) + if ``bbox_format`` is ``'xyxy'`` or ``'xywh'``, and (..., 2) if + ``bbox_format`` is ``'center'`` + image_size (tuple): The image shape in [w, h] + bbox_format (str): The bbox format. Options are ``'xywh'``, ``'xyxy'`` + and ``'center'``. + direction (str): The flip direction. Options are ``'horizontal'``, + ``'vertical'`` and ``'diagonal'``. Defaults to ``'horizontal'`` + + Returns: + np.ndarray: The flipped bounding boxes. + """ + direction_options = {'horizontal', 'vertical', 'diagonal'} + assert direction in direction_options, ( + f'Invalid flipping direction "{direction}". ' + f'Options are {direction_options}') + + format_options = {'xywh', 'xyxy', 'center'} + assert bbox_format in format_options, ( + f'Invalid bbox format "{bbox_format}". ' + f'Options are {format_options}') + + bbox_flipped = bbox.copy() + w, h = image_size + + # TODO: consider using "integer corner" coordinate system + if direction == 'horizontal': + if bbox_format == 'xywh' or bbox_format == 'center': + bbox_flipped[..., 0] = w - bbox[..., 0] - 1 + elif bbox_format == 'xyxy': + bbox_flipped[..., ::2] = w - bbox[..., ::2] - 1 + elif direction == 'vertical': + if bbox_format == 'xywh' or bbox_format == 'center': + bbox_flipped[..., 1] = h - bbox[..., 1] - 1 + elif bbox_format == 'xyxy': + bbox_flipped[..., 1::2] = h - bbox[..., 1::2] - 1 + elif direction == 'diagonal': + if bbox_format == 'xywh' or bbox_format == 'center': + bbox_flipped[..., :2] = [w, h] - bbox[..., :2] - 1 + elif bbox_format == 'xyxy': + bbox_flipped[...] = [w, h, w, h] - bbox - 1 + + return bbox_flipped + + +def get_udp_warp_matrix( + center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], +) -> np.ndarray: + """Calculate the affine transformation matrix under the unbiased + constraint. See `UDP (CVPR 2020)`_ for details. + + Note: + + - The bbox number: N + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (tuple): Size ([w, h]) of the output image + + Returns: + np.ndarray: A 2x3 transformation matrix + + .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524 + """ + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + + input_size = center * 2 + rot_rad = np.deg2rad(rot) + warp_mat = np.zeros((2, 3), dtype=np.float32) + scale_x = (output_size[0] - 1) / scale[0] + scale_y = (output_size[1] - 1) / scale[1] + warp_mat[0, 0] = math.cos(rot_rad) * scale_x + warp_mat[0, 1] = -math.sin(rot_rad) * scale_x + warp_mat[0, 2] = scale_x * (-0.5 * input_size[0] * math.cos(rot_rad) + + 0.5 * input_size[1] * math.sin(rot_rad) + + 0.5 * scale[0]) + warp_mat[1, 0] = math.sin(rot_rad) * scale_y + warp_mat[1, 1] = math.cos(rot_rad) * scale_y + warp_mat[1, 2] = scale_y * (-0.5 * input_size[0] * math.sin(rot_rad) - + 0.5 * input_size[1] * math.cos(rot_rad) + + 0.5 * scale[1]) + return warp_mat + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + assert len(shift) == 2 + + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + return warp_mat + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray): + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c diff --git a/mmpose/structures/keypoint/__init__.py b/mmpose/structures/keypoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8d5a24c7a8f8f08b4ef8f37b8628249e3734ba7 --- /dev/null +++ b/mmpose/structures/keypoint/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .transforms import flip_keypoints + +__all__ = ['flip_keypoints'] diff --git a/mmpose/structures/keypoint/__pycache__/__init__.cpython-38.pyc b/mmpose/structures/keypoint/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f346107454a95a3b9c660968a43b9b88beb2ae8a Binary files /dev/null and b/mmpose/structures/keypoint/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/structures/keypoint/__pycache__/transforms.cpython-38.pyc b/mmpose/structures/keypoint/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22560b454c20a0244b24d8e9c500b33b0b9a7eeb Binary files /dev/null and b/mmpose/structures/keypoint/__pycache__/transforms.cpython-38.pyc differ diff --git a/mmpose/structures/keypoint/transforms.py b/mmpose/structures/keypoint/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..99adaa130619e2e6040d5cc2d055ea09e99bec23 --- /dev/null +++ b/mmpose/structures/keypoint/transforms.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import numpy as np + + +def flip_keypoints(keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray], + image_size: Tuple[int, int], + flip_indices: List[int], + direction: str = 'horizontal' + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Flip keypoints in the given direction. + + Note: + + - keypoint number: K + - keypoint dimension: D + + Args: + keypoints (np.ndarray): Keypoints in shape (..., K, D) + keypoints_visible (np.ndarray, optional): The visibility of keypoints + in shape (..., K, 1). Set ``None`` if the keypoint visibility is + unavailable + image_size (tuple): The image shape in [w, h] + flip_indices (List[int]): The indices of each keypoint's symmetric + keypoint + direction (str): The flip direction. Options are ``'horizontal'``, + ``'vertical'`` and ``'diagonal'``. Defaults to ``'horizontal'`` + + Returns: + tuple: + - keypoints_flipped (np.ndarray): Flipped keypoints in shape + (..., K, D) + - keypoints_visible_flipped (np.ndarray, optional): Flipped keypoints' + visibility in shape (..., K, 1). Return ``None`` if the input + ``keypoints_visible`` is ``None`` + """ + + assert keypoints.shape[:-1] == keypoints_visible.shape, ( + f'Mismatched shapes of keypoints {keypoints.shape} and ' + f'keypoints_visible {keypoints_visible.shape}') + + direction_options = {'horizontal', 'vertical', 'diagonal'} + assert direction in direction_options, ( + f'Invalid flipping direction "{direction}". ' + f'Options are {direction_options}') + + # swap the symmetric keypoint pairs + if direction == 'horizontal' or direction == 'vertical': + keypoints = keypoints[..., flip_indices, :] + if keypoints_visible is not None: + keypoints_visible = keypoints_visible[..., flip_indices] + + # flip the keypoints + w, h = image_size + if direction == 'horizontal': + keypoints[..., 0] = w - 1 - keypoints[..., 0] + elif direction == 'vertical': + keypoints[..., 1] = h - 1 - keypoints[..., 1] + else: + keypoints = [w, h] - keypoints - 1 + + return keypoints, keypoints_visible diff --git a/mmpose/structures/multilevel_pixel_data.py b/mmpose/structures/multilevel_pixel_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bea191e7297c233cc129f2da09ab5a4c6793fa0f --- /dev/null +++ b/mmpose/structures/multilevel_pixel_data.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import abc +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union + +import numpy as np +import torch +from mmengine.structures import BaseDataElement, PixelData +from mmengine.utils import is_list_of + +IndexType = Union[str, slice, int, list, torch.LongTensor, + torch.cuda.LongTensor, torch.BoolTensor, + torch.cuda.BoolTensor, np.ndarray] + + +class MultilevelPixelData(BaseDataElement): + """Data structure for multi-level pixel-wise annotations or predictions. + + All data items in ``data_fields`` of ``MultilevelPixelData`` are lists + of np.ndarray or torch.Tensor, and should meet the following requirements: + + - Have the same length, which is the number of levels + - At each level, the data should have 3 dimensions in order of channel, + height and weight + - At each level, the data should have the same height and weight + + Examples: + >>> metainfo = dict(num_keypoints=17) + >>> sizes = [(64, 48), (128, 96), (256, 192)] + >>> heatmaps = [np.random.rand(17, h, w) for h, w in sizes] + >>> masks = [torch.rand(1, h, w) for h, w in sizes] + >>> data = MultilevelPixelData(metainfo=metainfo, + ... heatmaps=heatmaps, + ... masks=masks) + + >>> # get data item + >>> heatmaps = data.heatmaps # A list of 3 numpy.ndarrays + >>> masks = data.masks # A list of 3 torch.Tensors + + >>> # get level + >>> data_l0 = data[0] # PixelData with fields 'heatmaps' and 'masks' + >>> data.nlevel + 3 + + >>> # get shape + >>> data.shape + ((64, 48), (128, 96), (256, 192)) + + >>> # set + >>> offset_maps = [torch.rand(2, h, w) for h, w in sizes] + >>> data.offset_maps = offset_maps + """ + + def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: + object.__setattr__(self, '_nlevel', None) + super().__init__(metainfo=metainfo, **kwargs) + + @property + def nlevel(self): + """Return the level number. + + Returns: + Optional[int]: The level number, or ``None`` if the data has not + been assigned. + """ + return self._nlevel + + def __getitem__(self, item: Union[int, str, list, + slice]) -> Union[PixelData, Sequence]: + if isinstance(item, int): + if self.nlevel is None or item >= self.nlevel: + raise IndexError( + f'Lcale index {item} out of range ({self.nlevel})') + return self.get(f'_level_{item}') + + if isinstance(item, str): + if item not in self: + raise KeyError(item) + return getattr(self, item) + + # TODO: support indexing by list and slice over levels + raise NotImplementedError( + f'{self.__class__.__name__} does not support index type ' + f'{type(item)}') + + def levels(self) -> List[PixelData]: + if self.nlevel: + return list(self[i] for i in range(self.nlevel)) + return [] + + @property + def shape(self) -> Optional[Tuple[Tuple]]: + """Get the shape of multi-level pixel data. + + Returns: + Optional[tuple]: A tuple of data shape at each level, or ``None`` + if the data has not been assigned. + """ + if self.nlevel is None: + return None + + return tuple(level.shape for level in self.levels()) + + def set_data(self, data: dict) -> None: + """Set or change key-value pairs in ``data_field`` by parameter + ``data``. + + Args: + data (dict): A dict contains annotations of image or + model predictions. + """ + assert isinstance(data, + dict), f'meta should be a `dict` but got {data}' + for k, v in data.items(): + self.set_field(v, k, field_type='data') + + def set_field(self, + value: Any, + name: str, + dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, + field_type: str = 'data') -> None: + """Special method for set union field, used as property.setter + functions.""" + assert field_type in ['metainfo', 'data'] + if dtype is not None: + assert isinstance( + value, + dtype), f'{value} should be a {dtype} but got {type(value)}' + + if name.startswith('_level_'): + raise AttributeError( + f'Cannot set {name} to be a field because the pattern ' + '<_level_{n}> is reserved for inner data field') + + if field_type == 'metainfo': + if name in self._data_fields: + raise AttributeError( + f'Cannot set {name} to be a field of metainfo ' + f'because {name} is already a data field') + self._metainfo_fields.add(name) + + else: + if name in self._metainfo_fields: + raise AttributeError( + f'Cannot set {name} to be a field of data ' + f'because {name} is already a metainfo field') + + if not isinstance(value, abc.Sequence): + raise TypeError( + 'The value should be a sequence (of numpy.ndarray or' + f'torch.Tesnor), but got a {type(value)}') + + if len(value) == 0: + raise ValueError('Setting empty value is not allowed') + + if not isinstance(value[0], (torch.Tensor, np.ndarray)): + raise TypeError( + 'The value should be a sequence of numpy.ndarray or' + f'torch.Tesnor, but got a sequence of {type(value[0])}') + + if self.nlevel is not None: + assert len(value) == self.nlevel, ( + f'The length of the value ({len(value)}) should match the' + f'number of the levels ({self.nlevel})') + else: + object.__setattr__(self, '_nlevel', len(value)) + for i in range(self.nlevel): + object.__setattr__(self, f'_level_{i}', PixelData()) + + for i, v in enumerate(value): + self[i].set_field(v, name, field_type='data') + + self._data_fields.add(name) + + object.__setattr__(self, name, value) + + def __delattr__(self, item: str): + """delete the item in dataelement. + + Args: + item (str): The key to delete. + """ + if item in ('_metainfo_fields', '_data_fields'): + raise AttributeError(f'{item} has been used as a ' + 'private attribute, which is immutable. ') + + if item in self._metainfo_fields: + super().__delattr__(item) + else: + for level in self.levels(): + level.__delattr__(item) + self._data_fields.remove(item) + + def __getattr__(self, name): + if name in {'_data_fields', '_metainfo_fields' + } or name not in self._data_fields: + raise AttributeError( + f'\'{self.__class__.__name__}\' object has no attribute ' + f'\'{name}\'') + + return [getattr(level, name) for level in self.levels()] + + def pop(self, *args) -> Any: + """pop property in data and metainfo as the same as python.""" + assert len(args) < 3, '``pop`` get more than 2 arguments' + name = args[0] + if name in self._metainfo_fields: + self._metainfo_fields.remove(name) + return self.__dict__.pop(*args) + + elif name in self._data_fields: + self._data_fields.remove(name) + return [level.pop(*args) for level in self.levels()] + + # with default value + elif len(args) == 2: + return args[1] + else: + # don't just use 'self.__dict__.pop(*args)' for only popping key in + # metainfo or data + raise KeyError(f'{args[0]} is not contained in metainfo or data') + + def _convert(self, apply_to: Type, + func: Callable[[Any], Any]) -> 'MultilevelPixelData': + """Convert data items with the given function. + + Args: + apply_to (Type): The type of data items to apply the conversion + func (Callable): The conversion function that takes a data item + as the input and return the converted result + + Returns: + MultilevelPixelData: the converted data element. + """ + new_data = self.new() + for k, v in self.items(): + if is_list_of(v, apply_to): + v = [func(_v) for _v in v] + data = {k: v} + new_data.set_data(data) + return new_data + + def cpu(self) -> 'MultilevelPixelData': + """Convert all tensors to CPU in data.""" + return self._convert(apply_to=torch.Tensor, func=lambda x: x.cpu()) + + def cuda(self) -> 'MultilevelPixelData': + """Convert all tensors to GPU in data.""" + return self._convert(apply_to=torch.Tensor, func=lambda x: x.cuda()) + + def detach(self) -> 'MultilevelPixelData': + """Detach all tensors in data.""" + return self._convert(apply_to=torch.Tensor, func=lambda x: x.detach()) + + def numpy(self) -> 'MultilevelPixelData': + """Convert all tensor to np.narray in data.""" + return self._convert( + apply_to=torch.Tensor, func=lambda x: x.detach().cpu().numpy()) + + def to_tensor(self) -> 'MultilevelPixelData': + """Convert all tensor to np.narray in data.""" + return self._convert( + apply_to=np.ndarray, func=lambda x: torch.from_numpy(x)) + + # Tensor-like methods + def to(self, *args, **kwargs) -> 'MultilevelPixelData': + """Apply same name function to all tensors in data_fields.""" + new_data = self.new() + for k, v in self.items(): + if hasattr(v[0], 'to'): + v = [v_.to(*args, **kwargs) for v_ in v] + data = {k: v} + new_data.set_data(data) + return new_data diff --git a/mmpose/structures/pose_data_sample.py b/mmpose/structures/pose_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1d69034e319c6c43870f56e06661c536adf5e2 --- /dev/null +++ b/mmpose/structures/pose_data_sample.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +from mmengine.structures import BaseDataElement, InstanceData, PixelData + +from mmpose.structures import MultilevelPixelData + + +class PoseDataSample(BaseDataElement): + """The base data structure of MMPose that is used as the interface between + modules. + + The attributes of ``PoseDataSample`` includes: + + - ``gt_instances``(InstanceData): Ground truth of instances with + keypoint annotations + - ``pred_instances``(InstanceData): Instances with keypoint + predictions + - ``gt_fields``(PixelData): Ground truth of spatial distribution + annotations like keypoint heatmaps and part affine fields (PAF) + - ``pred_fields``(PixelData): Predictions of spatial distributions + + Examples: + >>> import torch + >>> from mmengine.structures import InstanceData, PixelData + >>> from mmpose.structures import PoseDataSample + + >>> pose_meta = dict(img_shape=(800, 1216), + ... crop_size=(256, 192), + ... heatmap_size=(64, 48)) + >>> gt_instances = InstanceData() + >>> gt_instances.bboxes = torch.rand((1, 4)) + >>> gt_instances.keypoints = torch.rand((1, 17, 2)) + >>> gt_instances.keypoints_visible = torch.rand((1, 17, 1)) + >>> gt_fields = PixelData() + >>> gt_fields.heatmaps = torch.rand((17, 64, 48)) + + >>> data_sample = PoseDataSample(gt_instances=gt_instances, + ... gt_fields=gt_fields, + ... metainfo=pose_meta) + >>> assert 'img_shape' in data_sample + >>> len(data_sample.gt_intances) + 1 + """ + + @property + def gt_instances(self) -> InstanceData: + return self._gt_instances + + @gt_instances.setter + def gt_instances(self, value: InstanceData): + self.set_field(value, '_gt_instances', dtype=InstanceData) + + @gt_instances.deleter + def gt_instances(self): + del self._gt_instances + + @property + def gt_instance_labels(self) -> InstanceData: + return self._gt_instance_labels + + @gt_instance_labels.setter + def gt_instance_labels(self, value: InstanceData): + self.set_field(value, '_gt_instance_labels', dtype=InstanceData) + + @gt_instance_labels.deleter + def gt_instance_labels(self): + del self._gt_instance_labels + + @property + def pred_instances(self) -> InstanceData: + return self._pred_instances + + @pred_instances.setter + def pred_instances(self, value: InstanceData): + self.set_field(value, '_pred_instances', dtype=InstanceData) + + @pred_instances.deleter + def pred_instances(self): + del self._pred_instances + + @property + def gt_fields(self) -> Union[PixelData, MultilevelPixelData]: + return self._gt_fields + + @gt_fields.setter + def gt_fields(self, value: Union[PixelData, MultilevelPixelData]): + self.set_field(value, '_gt_fields', dtype=type(value)) + + @gt_fields.deleter + def gt_fields(self): + del self._gt_fields + + @property + def pred_fields(self) -> PixelData: + return self._pred_heatmaps + + @pred_fields.setter + def pred_fields(self, value: PixelData): + self.set_field(value, '_pred_heatmaps', dtype=PixelData) + + @pred_fields.deleter + def pred_fields(self): + del self._pred_heatmaps diff --git a/mmpose/structures/utils.py b/mmpose/structures/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..882cda86037bd2c5d68b938df23dfe91b4007957 --- /dev/null +++ b/mmpose/structures/utils.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List + +import cv2 +import numpy as np +import torch +from mmengine.structures import InstanceData, PixelData +from mmengine.utils import is_list_of + +from .bbox.transforms import get_warp_matrix +from .pose_data_sample import PoseDataSample + + +def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample: + """Merge the given data samples into a single data sample. + + This function can be used to merge the top-down predictions with + bboxes from the same image. The merged data sample will contain all + instances from the input data samples, and the identical metainfo with + the first input data sample. + + Args: + data_samples (List[:obj:`PoseDataSample`]): The data samples to + merge + + Returns: + PoseDataSample: The merged data sample. + """ + + if not is_list_of(data_samples, PoseDataSample): + raise ValueError('Invalid input type, should be a list of ' + ':obj:`PoseDataSample`') + + if len(data_samples) == 0: + warnings.warn('Try to merge an empty list of data samples.') + return PoseDataSample() + + merged = PoseDataSample(metainfo=data_samples[0].metainfo) + + if 'gt_instances' in data_samples[0]: + merged.gt_instances = InstanceData.cat( + [d.gt_instances for d in data_samples]) + + if 'pred_instances' in data_samples[0]: + merged.pred_instances = InstanceData.cat( + [d.pred_instances for d in data_samples]) + + if 'pred_fields' in data_samples[0] and 'heatmaps' in data_samples[ + 0].pred_fields: + reverted_heatmaps = [ + revert_heatmap(data_sample.pred_fields.heatmaps, + data_sample.gt_instances.bbox_centers, + data_sample.gt_instances.bbox_scales, + data_sample.ori_shape) + for data_sample in data_samples + ] + + merged_heatmaps = np.max(reverted_heatmaps, axis=0) + pred_fields = PixelData() + pred_fields.set_data(dict(heatmaps=merged_heatmaps)) + merged.pred_fields = pred_fields + + if 'gt_fields' in data_samples[0] and 'heatmaps' in data_samples[ + 0].gt_fields: + reverted_heatmaps = [ + revert_heatmap(data_sample.gt_fields.heatmaps, + data_sample.gt_instances.bbox_centers, + data_sample.gt_instances.bbox_scales, + data_sample.ori_shape) + for data_sample in data_samples + ] + + merged_heatmaps = np.max(reverted_heatmaps, axis=0) + gt_fields = PixelData() + gt_fields.set_data(dict(heatmaps=merged_heatmaps)) + merged.gt_fields = gt_fields + + return merged + + +def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape): + """Revert predicted heatmap on the original image. + + Args: + heatmap (np.ndarray or torch.tensor): predicted heatmap. + bbox_center (np.ndarray): bounding box center coordinate. + bbox_scale (np.ndarray): bounding box scale. + img_shape (tuple or list): size of original image. + """ + if torch.is_tensor(heatmap): + heatmap = heatmap.cpu().detach().numpy() + + ndim = heatmap.ndim + # [K, H, W] -> [H, W, K] + if ndim == 3: + heatmap = heatmap.transpose(1, 2, 0) + + hm_h, hm_w = heatmap.shape[:2] + img_h, img_w = img_shape + warp_mat = get_warp_matrix( + bbox_center.reshape((2, )), + bbox_scale.reshape((2, )), + rot=0, + output_size=(hm_w, hm_h), + inv=True) + + heatmap = cv2.warpAffine( + heatmap, warp_mat, (img_w, img_h), flags=cv2.INTER_LINEAR) + + # [H, W, K] -> [K, H, W] + if ndim == 3: + heatmap = heatmap.transpose(2, 0, 1) + + return heatmap + + +def split_instances(instances: InstanceData) -> List[InstanceData]: + """Convert instances into a list where each element is a dict that contains + information about one instance.""" + results = [] + + # return an empty list if there is no instance detected by the model + if instances is None: + return results + + for i in range(len(instances.keypoints)): + result = dict( + keypoints=instances.keypoints[i].tolist(), + keypoint_scores=instances.keypoint_scores[i].tolist(), + ) + if 'bboxes' in instances: + result['bbox'] = instances.bboxes[i].tolist(), + if 'bbox_scores' in instances: + result['bbox_score'] = instances.bbox_scores[i] + results.append(result) + + return results diff --git a/mmpose/testing/__init__.py b/mmpose/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5612dac6c66e3bf7c2bad86154ac62c9d5e9529a --- /dev/null +++ b/mmpose/testing/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ._utils import (get_coco_sample, get_config_file, get_packed_inputs, + get_pose_estimator_cfg, get_repo_dir) + +__all__ = [ + 'get_packed_inputs', 'get_coco_sample', 'get_config_file', + 'get_pose_estimator_cfg', 'get_repo_dir' +] diff --git a/mmpose/testing/_utils.py b/mmpose/testing/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1908129be858c5ff7c5b5b9a6e14e0f03858e53d --- /dev/null +++ b/mmpose/testing/_utils.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from copy import deepcopy +from typing import Optional + +import numpy as np +import torch +from mmengine.config import Config +from mmengine.dataset import pseudo_collate +from mmengine.structures import InstanceData, PixelData + +from mmpose.structures import MultilevelPixelData, PoseDataSample +from mmpose.structures.bbox import bbox_xyxy2cs + + +def get_coco_sample( + img_shape=(240, 320), + img_fill: Optional[int] = None, + num_instances=1, + with_bbox_cs=True, + with_img_mask=False, + random_keypoints_visible=False, + non_occlusion=False): + """Create a dummy data sample in COCO style.""" + rng = np.random.RandomState(0) + h, w = img_shape + if img_fill is None: + img = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8) + else: + img = np.full((h, w, 3), img_fill, dtype=np.uint8) + + if non_occlusion: + bbox = _rand_bboxes(rng, num_instances, w / num_instances, h) + for i in range(num_instances): + bbox[i, 0::2] += w / num_instances * i + else: + bbox = _rand_bboxes(rng, num_instances, w, h) + + keypoints = _rand_keypoints(rng, bbox, 17) + if random_keypoints_visible: + keypoints_visible = np.random.randint(0, 2, (num_instances, + 17)).astype(np.float32) + else: + keypoints_visible = np.full((num_instances, 17), 1, dtype=np.float32) + + upper_body_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + lower_body_ids = [11, 12, 13, 14, 15, 16] + flip_pairs = [[2, 1], [1, 2], [4, 3], [3, 4], [6, 5], [5, 6], [8, 7], + [7, 8], [10, 9], [9, 10], [12, 11], [11, 12], [14, 13], + [13, 14], [16, 15], [15, 16]] + flip_indices = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + dataset_keypoint_weights = np.array([ + 1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2, 1.2, 1.5, + 1.5 + ]).astype(np.float32) + + data = { + 'img': img, + 'img_shape': img_shape, + 'ori_shape': img_shape, + 'bbox': bbox, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'upper_body_ids': upper_body_ids, + 'lower_body_ids': lower_body_ids, + 'flip_pairs': flip_pairs, + 'flip_indices': flip_indices, + 'dataset_keypoint_weights': dataset_keypoint_weights, + 'invalid_segs': [], + } + + if with_bbox_cs: + data['bbox_center'], data['bbox_scale'] = bbox_xyxy2cs(data['bbox']) + + if with_img_mask: + data['img_mask'] = np.random.randint(0, 2, (h, w), dtype=np.uint8) + + return data + + +def get_packed_inputs(batch_size=2, + num_instances=1, + num_keypoints=17, + num_levels=1, + img_shape=(256, 192), + input_size=(192, 256), + heatmap_size=(48, 64), + simcc_split_ratio=2.0, + with_heatmap=True, + with_reg_label=True, + with_simcc_label=True): + """Create a dummy batch of model inputs and data samples.""" + rng = np.random.RandomState(0) + + inputs_list = [] + for idx in range(batch_size): + inputs = dict() + + # input + h, w = img_shape + image = rng.randint(0, 255, size=(3, h, w), dtype=np.uint8) + inputs['inputs'] = torch.from_numpy(image) + + # meta + img_meta = { + 'id': idx, + 'img_id': idx, + 'img_path': '.png', + 'img_shape': img_shape, + 'input_size': input_size, + 'flip': False, + 'flip_direction': None, + 'flip_indices': list(range(num_keypoints)) + } + + np.random.shuffle(img_meta['flip_indices']) + data_sample = PoseDataSample(metainfo=img_meta) + + # gt_instance + gt_instances = InstanceData() + gt_instance_labels = InstanceData() + bboxes = _rand_bboxes(rng, num_instances, w, h) + bbox_centers, bbox_scales = bbox_xyxy2cs(bboxes) + + keypoints = _rand_keypoints(rng, bboxes, num_keypoints) + keypoints_visible = np.ones((num_instances, num_keypoints), + dtype=np.float32) + + # [N, K] -> [N, num_levels, K] + # keep the first dimension as the num_instances + if num_levels > 1: + keypoint_weights = np.tile(keypoints_visible[:, None], + (1, num_levels, 1)) + else: + keypoint_weights = keypoints_visible.copy() + + gt_instances.bboxes = bboxes + gt_instances.bbox_centers = bbox_centers + gt_instances.bbox_scales = bbox_scales + gt_instances.bbox_scores = np.ones((num_instances, ), dtype=np.float32) + gt_instances.keypoints = keypoints + gt_instances.keypoints_visible = keypoints_visible + + gt_instance_labels.keypoint_weights = torch.FloatTensor( + keypoint_weights) + + if with_reg_label: + gt_instance_labels.keypoint_labels = torch.FloatTensor(keypoints / + input_size) + + if with_simcc_label: + len_x = np.around(input_size[0] * simcc_split_ratio) + len_y = np.around(input_size[1] * simcc_split_ratio) + gt_instance_labels.keypoint_x_labels = torch.FloatTensor( + _rand_simcc_label(rng, num_instances, num_keypoints, len_x)) + gt_instance_labels.keypoint_y_labels = torch.FloatTensor( + _rand_simcc_label(rng, num_instances, num_keypoints, len_y)) + + # gt_fields + if with_heatmap: + if num_levels == 1: + gt_fields = PixelData() + # generate single-level heatmaps + W, H = heatmap_size + heatmaps = rng.rand(num_keypoints, H, W) + gt_fields.heatmaps = torch.FloatTensor(heatmaps) + else: + # generate multilevel heatmaps + heatmaps = [] + for _ in range(num_levels): + W, H = heatmap_size + heatmaps_ = rng.rand(num_keypoints, H, W) + heatmaps.append(torch.FloatTensor(heatmaps_)) + # [num_levels*K, H, W] + gt_fields = MultilevelPixelData() + gt_fields.heatmaps = heatmaps + data_sample.gt_fields = gt_fields + + data_sample.gt_instances = gt_instances + data_sample.gt_instance_labels = gt_instance_labels + + inputs['data_samples'] = data_sample + inputs_list.append(inputs) + + packed_inputs = pseudo_collate(inputs_list) + return packed_inputs + + +def _rand_keypoints(rng, bboxes, num_keypoints): + n = bboxes.shape[0] + relative_pos = rng.rand(n, num_keypoints, 2) + keypoints = relative_pos * bboxes[:, None, :2] + ( + 1 - relative_pos) * bboxes[:, None, 2:4] + + return keypoints + + +def _rand_simcc_label(rng, num_instances, num_keypoints, len_feats): + simcc_label = rng.rand(num_instances, num_keypoints, int(len_feats)) + return simcc_label + + +def _rand_bboxes(rng, num_instances, img_w, img_h): + cx, cy = rng.rand(num_instances, 2).T + bw, bh = 0.2 + 0.8 * rng.rand(num_instances, 2).T + + tl_x = ((cx * img_w) - (img_w * bw / 2)).clip(0, img_w) + tl_y = ((cy * img_h) - (img_h * bh / 2)).clip(0, img_h) + br_x = ((cx * img_w) + (img_w * bw / 2)).clip(0, img_w) + br_y = ((cy * img_h) + (img_h * bh / 2)).clip(0, img_h) + + bboxes = np.vstack([tl_x, tl_y, br_x, br_y]).T + return bboxes + + +def get_repo_dir(): + """Return the path of the MMPose repo directory.""" + try: + # Assume the function in invoked is the source mmpose repo + repo_dir = osp.dirname(osp.dirname(osp.dirname(__file__))) + except NameError: + # For IPython development when __file__ is not defined + import mmpose + repo_dir = osp.dirname(osp.dirname(mmpose.__file__)) + + return repo_dir + + +def get_config_file(fn: str): + """Return full path of a config file from the given relative path.""" + repo_dir = get_repo_dir() + if fn.startswith('configs'): + fn_config = osp.join(repo_dir, fn) + else: + fn_config = osp.join(repo_dir, 'configs', fn) + + if not osp.isfile(fn_config): + raise FileNotFoundError(f'Cannot find config file {fn_config}') + + return fn_config + + +def get_pose_estimator_cfg(fn: str): + """Load model config from a config file.""" + + fn_config = get_config_file(fn) + config = Config.fromfile(fn_config) + return deepcopy(config.model) diff --git a/mmpose/utils/__init__.py b/mmpose/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c48ca01cea586d0e35c2f6daf3138ab94fe4e613 --- /dev/null +++ b/mmpose/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .camera import SimpleCamera, SimpleCameraTorch +from .collect_env import collect_env +from .config_utils import adapt_mmdet_pipeline +from .logger import get_root_logger +from .setup_env import register_all_modules, setup_multi_processes +from .timer import StopWatch + +__all__ = [ + 'get_root_logger', 'collect_env', 'StopWatch', 'setup_multi_processes', + 'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch', + 'adapt_mmdet_pipeline' +] diff --git a/mmpose/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71aea1227c1c32fb2e53744620ea9bd2f2f8417b Binary files /dev/null and b/mmpose/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/utils/__pycache__/camera.cpython-38.pyc b/mmpose/utils/__pycache__/camera.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e2183ee7296dd3f7e4c57edfb47ff9055f27334 Binary files /dev/null and b/mmpose/utils/__pycache__/camera.cpython-38.pyc differ diff --git a/mmpose/utils/__pycache__/collect_env.cpython-38.pyc b/mmpose/utils/__pycache__/collect_env.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ced64105c27986680519500642887e6659699a8 Binary files /dev/null and b/mmpose/utils/__pycache__/collect_env.cpython-38.pyc differ diff --git a/mmpose/utils/__pycache__/config_utils.cpython-38.pyc b/mmpose/utils/__pycache__/config_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3ca7e8d258ba1289c85923cd650aebadf596f9e Binary files /dev/null and b/mmpose/utils/__pycache__/config_utils.cpython-38.pyc differ diff --git a/mmpose/utils/__pycache__/logger.cpython-38.pyc b/mmpose/utils/__pycache__/logger.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5657e594754b35bb2a067ea02a7b1d8f34311c6f Binary files /dev/null and b/mmpose/utils/__pycache__/logger.cpython-38.pyc differ diff --git a/mmpose/utils/__pycache__/setup_env.cpython-38.pyc b/mmpose/utils/__pycache__/setup_env.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..351b99a3bf9dedd7be3b45aea0cb6788a084150e Binary files /dev/null and b/mmpose/utils/__pycache__/setup_env.cpython-38.pyc differ diff --git a/mmpose/utils/__pycache__/tensor_utils.cpython-38.pyc b/mmpose/utils/__pycache__/tensor_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d077c70c227865b72085a6b2e3a4ae38fb9cbf2 Binary files /dev/null and b/mmpose/utils/__pycache__/tensor_utils.cpython-38.pyc differ diff --git a/mmpose/utils/__pycache__/timer.cpython-38.pyc b/mmpose/utils/__pycache__/timer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95b8a96edd7cb726ae8d36d71cd57377336a60f1 Binary files /dev/null and b/mmpose/utils/__pycache__/timer.cpython-38.pyc differ diff --git a/mmpose/utils/__pycache__/typing.cpython-38.pyc b/mmpose/utils/__pycache__/typing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c114174ea50b5264da10b825ece386ae6f1e3ece Binary files /dev/null and b/mmpose/utils/__pycache__/typing.cpython-38.pyc differ diff --git a/mmpose/utils/camera.py b/mmpose/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..a7759d308f38fda99fcf56910b09251d24eccbed --- /dev/null +++ b/mmpose/utils/camera.py @@ -0,0 +1,280 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +import numpy as np +import torch +from mmengine.registry import Registry + +CAMERAS = Registry('camera') + + +class SingleCameraBase(metaclass=ABCMeta): + """Base class for single camera model. + + Args: + param (dict): Camera parameters + + Methods: + world_to_camera: Project points from world coordinates to camera + coordinates + camera_to_world: Project points from camera coordinates to world + coordinates + camera_to_pixel: Project points from camera coordinates to pixel + coordinates + world_to_pixel: Project points from world coordinates to pixel + coordinates + """ + + @abstractmethod + def __init__(self, param): + """Load camera parameters and check validity.""" + + def world_to_camera(self, X): + """Project points from world coordinates to camera coordinates.""" + raise NotImplementedError + + def camera_to_world(self, X): + """Project points from camera coordinates to world coordinates.""" + raise NotImplementedError + + def camera_to_pixel(self, X): + """Project points from camera coordinates to pixel coordinates.""" + raise NotImplementedError + + def world_to_pixel(self, X): + """Project points from world coordinates to pixel coordinates.""" + _X = self.world_to_camera(X) + return self.camera_to_pixel(_X) + + +@CAMERAS.register_module() +class SimpleCamera(SingleCameraBase): + """Camera model to calculate coordinate transformation with given + intrinsic/extrinsic camera parameters. + + Note: + The keypoint coordinate should be an np.ndarray with a shape of + [...,J, C] where J is the keypoint number of an instance, and C is + the coordinate dimension. For example: + + [J, C]: shape of joint coordinates of a person with J joints. + [N, J, C]: shape of a batch of person joint coordinates. + [N, T, J, C]: shape of a batch of pose sequences. + + Args: + param (dict): camera parameters including: + - R: 3x3, camera rotation matrix (camera-to-world) + - T: 3x1, camera translation (camera-to-world) + - K: (optional) 2x3, camera intrinsic matrix + - k: (optional) nx1, camera radial distortion coefficients + - p: (optional) mx1, camera tangential distortion coefficients + - f: (optional) 2x1, camera focal length + - c: (optional) 2x1, camera center + if K is not provided, it will be calculated from f and c. + + Methods: + world_to_camera: Project points from world coordinates to camera + coordinates + camera_to_pixel: Project points from camera coordinates to pixel + coordinates + world_to_pixel: Project points from world coordinates to pixel + coordinates + """ + + def __init__(self, param): + + self.param = {} + # extrinsic param + R = np.array(param['R'], dtype=np.float32) + T = np.array(param['T'], dtype=np.float32) + assert R.shape == (3, 3) + assert T.shape == (3, 1) + # The camera matrices are transposed in advance because the joint + # coordinates are stored as row vectors. + self.param['R_c2w'] = R.T + self.param['T_c2w'] = T.T + self.param['R_w2c'] = R + self.param['T_w2c'] = -self.param['T_c2w'] @ self.param['R_w2c'] + + # intrinsic param + if 'K' in param: + K = np.array(param['K'], dtype=np.float32) + assert K.shape == (2, 3) + self.param['K'] = K.T + self.param['f'] = np.array([K[0, 0], K[1, 1]])[:, np.newaxis] + self.param['c'] = np.array([K[0, 2], K[1, 2]])[:, np.newaxis] + elif 'f' in param and 'c' in param: + f = np.array(param['f'], dtype=np.float32) + c = np.array(param['c'], dtype=np.float32) + assert f.shape == (2, 1) + assert c.shape == (2, 1) + self.param['K'] = np.concatenate((np.diagflat(f), c), axis=-1).T + self.param['f'] = f + self.param['c'] = c + else: + raise ValueError('Camera intrinsic parameters are missing. ' + 'Either "K" or "f"&"c" should be provided.') + + # distortion param + if 'k' in param and 'p' in param: + self.undistortion = True + self.param['k'] = np.array(param['k'], dtype=np.float32).flatten() + self.param['p'] = np.array(param['p'], dtype=np.float32).flatten() + assert self.param['k'].size in {3, 6} + assert self.param['p'].size == 2 + else: + self.undistortion = False + + def world_to_camera(self, X): + assert isinstance(X, np.ndarray) + assert X.ndim >= 2 and X.shape[-1] == 3 + return X @ self.param['R_w2c'] + self.param['T_w2c'] + + def camera_to_world(self, X): + assert isinstance(X, np.ndarray) + assert X.ndim >= 2 and X.shape[-1] == 3 + return X @ self.param['R_c2w'] + self.param['T_c2w'] + + def camera_to_pixel(self, X): + assert isinstance(X, np.ndarray) + assert X.ndim >= 2 and X.shape[-1] == 3 + + _X = X / X[..., 2:] + + if self.undistortion: + k = self.param['k'] + p = self.param['p'] + _X_2d = _X[..., :2] + r2 = (_X_2d**2).sum(-1) + radial = 1 + sum(ki * r2**(i + 1) for i, ki in enumerate(k[:3])) + if k.size == 6: + radial /= 1 + sum( + (ki * r2**(i + 1) for i, ki in enumerate(k[3:]))) + + tangential = 2 * (p[1] * _X[..., 0] + p[0] * _X[..., 1]) + + _X[..., :2] = _X_2d * (radial + tangential)[..., None] + np.outer( + r2, p[::-1]).reshape(_X_2d.shape) + return _X @ self.param['K'] + + def pixel_to_camera(self, X): + assert isinstance(X, np.ndarray) + assert X.ndim >= 2 and X.shape[-1] == 3 + _X = X.copy() + _X[:, :2] = (X[:, :2] - self.param['c'].T) / self.param['f'].T * X[:, + [2]] + return _X + + +@CAMERAS.register_module() +class SimpleCameraTorch(SingleCameraBase): + """Camera model to calculate coordinate transformation with given + intrinsic/extrinsic camera parameters. + + Notes: + The keypoint coordinate should be an np.ndarray with a shape of + [...,J, C] where J is the keypoint number of an instance, and C is + the coordinate dimension. For example: + + [J, C]: shape of joint coordinates of a person with J joints. + [N, J, C]: shape of a batch of person joint coordinates. + [N, T, J, C]: shape of a batch of pose sequences. + + Args: + param (dict): camera parameters including: + - R: 3x3, camera rotation matrix (camera-to-world) + - T: 3x1, camera translation (camera-to-world) + - K: (optional) 2x3, camera intrinsic matrix + - k: (optional) nx1, camera radial distortion coefficients + - p: (optional) mx1, camera tangential distortion coefficients + - f: (optional) 2x1, camera focal length + - c: (optional) 2x1, camera center + if K is not provided, it will be calculated from f and c. + + Methods: + world_to_camera: Project points from world coordinates to camera + coordinates + camera_to_pixel: Project points from camera coordinates to pixel + coordinates + world_to_pixel: Project points from world coordinates to pixel + coordinates + """ + + def __init__(self, param, device): + + self.param = {} + # extrinsic param + R = torch.tensor(param['R'], device=device) + T = torch.tensor(param['T'], device=device) + + assert R.shape == (3, 3) + assert T.shape == (3, 1) + # The camera matrices are transposed in advance because the joint + # coordinates are stored as row vectors. + self.param['R_c2w'] = R.T + self.param['T_c2w'] = T.T + self.param['R_w2c'] = R + self.param['T_w2c'] = -self.param['T_c2w'] @ self.param['R_w2c'] + + # intrinsic param + if 'K' in param: + K = torch.tensor(param['K'], device=device) + assert K.shape == (2, 3) + self.param['K'] = K.T + self.param['f'] = torch.tensor([[K[0, 0]], [K[1, 1]]], + device=device) + self.param['c'] = torch.tensor([[K[0, 2]], [K[1, 2]]], + device=device) + elif 'f' in param and 'c' in param: + f = torch.tensor(param['f'], device=device) + c = torch.tensor(param['c'], device=device) + assert f.shape == (2, 1) + assert c.shape == (2, 1) + self.param['K'] = torch.cat([torch.diagflat(f), c], dim=-1).T + self.param['f'] = f + self.param['c'] = c + else: + raise ValueError('Camera intrinsic parameters are missing. ' + 'Either "K" or "f"&"c" should be provided.') + + # distortion param + if 'k' in param and 'p' in param: + self.undistortion = True + self.param['k'] = torch.tensor(param['k'], device=device).view(-1) + self.param['p'] = torch.tensor(param['p'], device=device).view(-1) + assert len(self.param['k']) in {3, 6} + assert len(self.param['p']) == 2 + else: + self.undistortion = False + + def world_to_camera(self, X): + assert isinstance(X, torch.Tensor) + assert X.ndim >= 2 and X.shape[-1] == 3 + return X @ self.param['R_w2c'] + self.param['T_w2c'] + + def camera_to_world(self, X): + assert isinstance(X, torch.Tensor) + assert X.ndim >= 2 and X.shape[-1] == 3 + return X @ self.param['R_c2w'] + self.param['T_c2w'] + + def camera_to_pixel(self, X): + assert isinstance(X, torch.Tensor) + assert X.ndim >= 2 and X.shape[-1] == 3 + + _X = X / X[..., 2:] + + if self.undistortion: + k = self.param['k'] + p = self.param['p'] + _X_2d = _X[..., :2] + r2 = (_X_2d**2).sum(-1) + radial = 1 + sum(ki * r2**(i + 1) for i, ki in enumerate(k[:3])) + if k.size == 6: + radial /= 1 + sum( + (ki * r2**(i + 1) for i, ki in enumerate(k[3:]))) + + tangential = 2 * (p[1] * _X[..., 0] + p[0] * _X[..., 1]) + + _X[..., :2] = _X_2d * (radial + tangential)[..., None] + torch.ger( + r2, p.flip([0])).reshape(_X_2d.shape) + return _X @ self.param['K'] diff --git a/mmpose/utils/collect_env.py b/mmpose/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..e8fb5f35e10fe6535b49b7eb7def1459b28835e3 --- /dev/null +++ b/mmpose/utils/collect_env.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + +import mmpose + + +def collect_env(): + env_info = collect_base_env() + env_info['MMPose'] = (mmpose.__version__ + '+' + get_git_hash(digits=7)) + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') diff --git a/mmpose/utils/config_utils.py b/mmpose/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f54d2ef24093a77933dbf8026465e3cdaf5e839 --- /dev/null +++ b/mmpose/utils/config_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.utils.typing import ConfigDict + + +def adapt_mmdet_pipeline(cfg: ConfigDict) -> ConfigDict: + """Converts pipeline types in MMDetection's test dataloader to use the + 'mmdet' namespace. + + Args: + cfg (ConfigDict): Configuration dictionary for MMDetection. + + Returns: + ConfigDict: Configuration dictionary with updated pipeline types. + """ + # use lazy import to avoid hard dependence on mmdet + from mmdet.datasets import transforms + + if 'test_dataloader' not in cfg: + return cfg + + pipeline = cfg.test_dataloader.dataset.pipeline + for trans in pipeline: + if trans['type'] in dir(transforms): + trans['type'] = 'mmdet.' + trans['type'] + + return cfg diff --git a/mmpose/utils/hooks.py b/mmpose/utils/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..b68940f2b7a8a618916ea5aab331e3ce45ba98e7 --- /dev/null +++ b/mmpose/utils/hooks.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + + +class OutputHook: + + def __init__(self, module, outputs=None, as_tensor=False): + self.outputs = outputs + self.as_tensor = as_tensor + self.layer_outputs = {} + self.register(module) + + def register(self, module): + + def hook_wrapper(name): + + def hook(model, input, output): + if self.as_tensor: + self.layer_outputs[name] = output + else: + if isinstance(output, list): + self.layer_outputs[name] = [ + out.detach().cpu().numpy() for out in output + ] + else: + self.layer_outputs[name] = output.detach().cpu().numpy( + ) + + return hook + + self.handles = [] + if isinstance(self.outputs, (list, tuple)): + for name in self.outputs: + try: + layer = rgetattr(module, name) + h = layer.register_forward_hook(hook_wrapper(name)) + except ModuleNotFoundError as module_not_found: + raise ModuleNotFoundError( + f'Module {name} not found') from module_not_found + self.handles.append(h) + + def remove(self): + for h in self.handles: + h.remove() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.remove() + + +# using wonder's beautiful simplification: +# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects +def rgetattr(obj, attr, *args): + + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split('.')) diff --git a/mmpose/utils/logger.py b/mmpose/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f67e56efeb998cf966e3729c90791b4a70f2bb84 --- /dev/null +++ b/mmpose/utils/logger.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +from mmengine.logging import MMLogger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Use `MMLogger` class in mmengine to get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmpose". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + return MMLogger('MMLogger', __name__.split('.')[0], log_file, log_level) diff --git a/mmpose/utils/setup_env.py b/mmpose/utils/setup_env.py new file mode 100644 index 0000000000000000000000000000000000000000..ff299539ef8cc83a17a24e41498c01ff4f26667f --- /dev/null +++ b/mmpose/utils/setup_env.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import os +import platform +import warnings + +import cv2 +import torch.multiprocessing as mp +from mmengine import DefaultScope + + +def setup_multi_processes(cfg): + """Setup multi-processing environment variables.""" + # set multi-process start method as `fork` to speed up the training + if platform.system() != 'Windows': + mp_start_method = cfg.get('mp_start_method', 'fork') + current_method = mp.get_start_method(allow_none=True) + if current_method is not None and current_method != mp_start_method: + warnings.warn( + f'Multi-processing start method `{mp_start_method}` is ' + f'different from the previous setting `{current_method}`.' + f'It will be force set to `{mp_start_method}`. You can change ' + f'this behavior by changing `mp_start_method` in your config.') + mp.set_start_method(mp_start_method, force=True) + + # disable opencv multithreading to avoid system being overloaded + opencv_num_threads = cfg.get('opencv_num_threads', 0) + cv2.setNumThreads(opencv_num_threads) + + # setup OMP threads + # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa + if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: + omp_num_threads = 1 + warnings.warn( + f'Setting OMP_NUM_THREADS environment variable for each process ' + f'to be {omp_num_threads} in default, to avoid your system being ' + f'overloaded, please further tune the variable for optimal ' + f'performance in your application as needed.') + os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) + + # setup MKL threads + if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: + mkl_num_threads = 1 + warnings.warn( + f'Setting MKL_NUM_THREADS environment variable for each process ' + f'to be {mkl_num_threads} in default, to avoid your system being ' + f'overloaded, please further tune the variable for optimal ' + f'performance in your application as needed.') + os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) + + +def register_all_modules(init_default_scope: bool = True) -> None: + """Register all modules in mmpose into the registries. + + Args: + init_default_scope (bool): Whether initialize the mmpose default scope. + When `init_default_scope=True`, the global default scope will be + set to `mmpose`, and all registries will build modules from mmpose's + registry node. To understand more about the registry, please refer + to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md + Defaults to True. + """ # noqa + + import mmpose.codecs # noqa: F401, F403 + import mmpose.datasets # noqa: F401,F403 + import mmpose.engine # noqa: F401,F403 + import mmpose.evaluation # noqa: F401,F403 + import mmpose.models # noqa: F401,F403 + import mmpose.visualization # noqa: F401,F403 + + if init_default_scope: + never_created = DefaultScope.get_current_instance() is None \ + or not DefaultScope.check_instance_created('mmpose') + if never_created: + DefaultScope.get_instance('mmpose', scope_name='mmpose') + return + current_scope = DefaultScope.get_current_instance() + if current_scope.scope_name != 'mmpose': + warnings.warn('The current default scope ' + f'"{current_scope.scope_name}" is not "mmpose", ' + '`register_all_modules` will force the current' + 'default scope to be "mmpose". If this is not ' + 'expected, please set `init_default_scope=False`.') + # avoid name conflict + new_instance_name = f'mmpose-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name='mmpose') diff --git a/mmpose/utils/tensor_utils.py b/mmpose/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1be73f899178ead4dd4ade9f621aedb07bec4258 --- /dev/null +++ b/mmpose/utils/tensor_utils.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Any, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.utils import is_seq_of +from torch import Tensor + + +def to_numpy(x: Union[Tensor, Sequence[Tensor]], + return_device: bool = False, + unzip: bool = False) -> Union[np.ndarray, tuple]: + """Convert torch tensor to numpy.ndarray. + + Args: + x (Tensor | Sequence[Tensor]): A single tensor or a sequence of + tensors + return_device (bool): Whether return the tensor device. Defaults to + ``False`` + unzip (bool): Whether unzip the input sequence. Defaults to ``False`` + + Returns: + np.ndarray | tuple: If ``return_device`` is ``True``, return a tuple + of converted numpy array(s) and the device indicator; otherwise only + return the numpy array(s) + """ + + if isinstance(x, Tensor): + arrays = x.detach().cpu().numpy() + device = x.device + elif is_seq_of(x, Tensor): + if unzip: + # convert (A, B) -> [(A[0], B[0]), (A[1], B[1]), ...] + arrays = [ + tuple(to_numpy(_x[None, :]) for _x in _each) + for _each in zip(*x) + ] + else: + arrays = [to_numpy(_x) for _x in x] + + device = x[0].device + + else: + raise ValueError(f'Invalid input type {type(x)}') + + if return_device: + return arrays, device + else: + return arrays + + +def to_tensor(x: Union[np.ndarray, Sequence[np.ndarray]], + device: Optional[Any] = None) -> Union[Tensor, Sequence[Tensor]]: + """Convert numpy.ndarray to torch tensor. + + Args: + x (np.ndarray | Sequence[np.ndarray]): A single np.ndarray or a + sequence of tensors + tensor (Any, optional): The device indicator. Defaults to ``None`` + + Returns: + tuple: + - Tensor | Sequence[Tensor]: The converted Tensor or Tensor sequence + """ + if isinstance(x, np.ndarray): + return torch.tensor(x, device=device) + elif is_seq_of(x, np.ndarray): + return [to_tensor(_x, device=device) for _x in x] + else: + raise ValueError(f'Invalid input type {type(x)}') diff --git a/mmpose/utils/timer.py b/mmpose/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..c219c04069d239605a7854b06a370876dbe8fd58 --- /dev/null +++ b/mmpose/utils/timer.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from contextlib import contextmanager +from functools import partial + +import numpy as np +from mmengine import Timer + + +class RunningAverage(): + r"""A helper class to calculate running average in a sliding window. + + Args: + window (int): The size of the sliding window. + """ + + def __init__(self, window: int = 1): + self.window = window + self._data = [] + + def update(self, value): + """Update a new data sample.""" + self._data.append(value) + self._data = self._data[-self.window:] + + def average(self): + """Get the average value of current window.""" + return np.mean(self._data) + + +class StopWatch: + r"""A helper class to measure FPS and detailed time consuming of each phase + in a video processing loop or similar scenarios. + + Args: + window (int): The sliding window size to calculate the running average + of the time consuming. + + Example: + >>> from mmpose.utils import StopWatch + >>> import time + >>> stop_watch = StopWatch(window=10) + >>> with stop_watch.timeit('total'): + >>> time.sleep(0.1) + >>> # 'timeit' support nested use + >>> with stop_watch.timeit('phase1'): + >>> time.sleep(0.1) + >>> with stop_watch.timeit('phase2'): + >>> time.sleep(0.2) + >>> time.sleep(0.2) + >>> report = stop_watch.report() + """ + + def __init__(self, window=1): + self.window = window + self._record = defaultdict(partial(RunningAverage, window=self.window)) + self._timer_stack = [] + + @contextmanager + def timeit(self, timer_name='_FPS_'): + """Timing a code snippet with an assigned name. + + Args: + timer_name (str): The unique name of the interested code snippet to + handle multiple timers and generate reports. Note that '_FPS_' + is a special key that the measurement will be in `fps` instead + of `millisecond`. Also see `report` and `report_strings`. + Default: '_FPS_'. + Note: + This function should always be used in a `with` statement, as shown + in the example. + """ + self._timer_stack.append((timer_name, Timer())) + try: + yield + finally: + timer_name, timer = self._timer_stack.pop() + self._record[timer_name].update(timer.since_start()) + + def report(self, key=None): + """Report timing information. + + Returns: + dict: The key is the timer name and the value is the \ + corresponding average time consuming. + """ + result = { + name: r.average() * 1000. + for name, r in self._record.items() + } + + if '_FPS_' in result: + result['_FPS_'] = 1000. / result.pop('_FPS_') + + if key is None: + return result + return result[key] + + def report_strings(self): + """Report timing information in texture strings. + + Returns: + list(str): Each element is the information string of a timed \ + event, in format of '{timer_name}: {time_in_ms}'. \ + Specially, if timer_name is '_FPS_', the result will \ + be converted to fps. + """ + result = self.report() + strings = [] + if '_FPS_' in result: + strings.append(f'FPS: {result["_FPS_"]:>5.1f}') + strings += [f'{name}: {val:>3.0f}' for name, val in result.items()] + return strings + + def reset(self): + self._record = defaultdict(list) + self._active_timer_stack = [] diff --git a/mmpose/utils/typing.py b/mmpose/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..557891b3b92e657de43eb50d4b5fbce7d369e7ee --- /dev/null +++ b/mmpose/utils/typing.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData, PixelData +from torch import Tensor + +from mmpose.structures import PoseDataSample + +# Type hint of config data +ConfigType = Union[ConfigDict, dict] +OptConfigType = Optional[ConfigType] +# Type hint of one or more config data +MultiConfig = Union[ConfigType, List[ConfigType]] +OptMultiConfig = Optional[MultiConfig] +# Type hint of data samples +SampleList = List[PoseDataSample] +OptSampleList = Optional[SampleList] +InstanceList = List[InstanceData] +PixelDataList = List[PixelData] +Predictions = Union[InstanceList, Tuple[InstanceList, PixelDataList]] +# Type hint of model outputs +ForwardResults = Union[Dict[str, Tensor], List[PoseDataSample], Tuple[Tensor], + Tensor] +# Type hint of features +# - Tuple[Tensor]: multi-level features extracted by the network +# - List[Tuple[Tensor]]: multiple feature pyramids for TTA +# - List[List[Tuple[Tensor]]]: multi-scale feature pyramids +Features = Union[Tuple[Tensor], List[Tuple[Tensor]], List[List[Tuple[Tensor]]]] diff --git a/mmpose/version.py b/mmpose/version.py new file mode 100644 index 0000000000000000000000000000000000000000..73312cc28dbe08a076c725e0b9a6ea8579db9d1c --- /dev/null +++ b/mmpose/version.py @@ -0,0 +1,31 @@ +# Copyright (c) Open-MMLab. All rights reserved. + +__version__ = '1.0.0' +short_version = __version__ + + +def parse_version_info(version_str): + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). + """ + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + elif x.find('b') != -1: + patch_version = x.split('b') + version_info.append(int(patch_version[0])) + version_info.append(f'b{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) diff --git a/mmpose/visualization/__init__.py b/mmpose/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..357d40a707bd5e87d65fc4236233658dd1aab18d --- /dev/null +++ b/mmpose/visualization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .local_visualizer import PoseLocalVisualizer + +__all__ = ['PoseLocalVisualizer'] diff --git a/mmpose/visualization/__pycache__/__init__.cpython-38.pyc b/mmpose/visualization/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..479042786d87f1d1c5e92f2b0f49d44d03c63d7e Binary files /dev/null and b/mmpose/visualization/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpose/visualization/__pycache__/local_visualizer.cpython-38.pyc b/mmpose/visualization/__pycache__/local_visualizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63fbedc3689d09829c2f9471b5768ce68a97c9ae Binary files /dev/null and b/mmpose/visualization/__pycache__/local_visualizer.cpython-38.pyc differ diff --git a/mmpose/visualization/__pycache__/opencv_backend_visualizer.cpython-38.pyc b/mmpose/visualization/__pycache__/opencv_backend_visualizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..399cb790289114ac8431e78c36c6989f53f07f8e Binary files /dev/null and b/mmpose/visualization/__pycache__/opencv_backend_visualizer.cpython-38.pyc differ diff --git a/mmpose/visualization/__pycache__/simcc_vis.cpython-38.pyc b/mmpose/visualization/__pycache__/simcc_vis.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7b3e5367f695d1d2bfd6e2253e41ef6aa623b1e Binary files /dev/null and b/mmpose/visualization/__pycache__/simcc_vis.cpython-38.pyc differ diff --git a/mmpose/visualization/local_visualizer.py b/mmpose/visualization/local_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..205993c006eb5349c0dc371e7d4124b19050cf2f --- /dev/null +++ b/mmpose/visualization/local_visualizer.py @@ -0,0 +1,583 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import mmcv +import numpy as np +import torch +from mmengine.dist import master_only +from mmengine.structures import InstanceData, PixelData + +from mmpose.datasets.datasets.utils import parse_pose_metainfo +from mmpose.registry import VISUALIZERS +from mmpose.structures import PoseDataSample +from .opencv_backend_visualizer import OpencvBackendVisualizer +from .simcc_vis import SimCCVisualizer + + +def _get_adaptive_scales(areas: np.ndarray, + min_area: int = 800, + max_area: int = 30000) -> np.ndarray: + """Get adaptive scales according to areas. + + The scale range is [0.5, 1.0]. When the area is less than + ``min_area``, the scale is 0.5 while the area is larger than + ``max_area``, the scale is 1.0. + + Args: + areas (ndarray): The areas of bboxes or masks with the + shape of (n, ). + min_area (int): Lower bound areas for adaptive scales. + Defaults to 800. + max_area (int): Upper bound areas for adaptive scales. + Defaults to 30000. + + Returns: + ndarray: The adaotive scales with the shape of (n, ). + """ + scales = 0.5 + (areas - min_area) / (max_area - min_area) + scales = np.clip(scales, 0.5, 1.0) + return scales + + +@VISUALIZERS.register_module() +class PoseLocalVisualizer(OpencvBackendVisualizer): + """MMPose Local Visualizer. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to ``None`` + vis_backends (list, optional): Visual backend config list. Defaults to + ``None`` + save_dir (str, optional): Save file dir for all storage backends. + If it is ``None``, the backend storage will not save any data. + Defaults to ``None`` + bbox_color (str, tuple(int), optional): Color of bbox lines. + The tuple of color should be in BGR order. Defaults to ``'green'`` + kpt_color (str, tuple(tuple(int)), optional): Color of keypoints. + The tuple of color should be in BGR order. Defaults to ``'red'`` + link_color (str, tuple(tuple(int)), optional): Color of skeleton. + The tuple of color should be in BGR order. Defaults to ``None`` + line_width (int, float): The width of lines. Defaults to 1 + radius (int, float): The radius of keypoints. Defaults to 4 + show_keypoint_weight (bool): Whether to adjust the transparency + of keypoints according to their score. Defaults to ``False`` + alpha (int, float): The transparency of bboxes. Defaults to ``0.8`` + + Examples: + >>> import numpy as np + >>> from mmengine.structures import InstanceData + >>> from mmpose.structures import PoseDataSample + >>> from mmpose.visualization import PoseLocalVisualizer + + >>> pose_local_visualizer = PoseLocalVisualizer(radius=1) + >>> image = np.random.randint(0, 256, + ... size=(10, 12, 3)).astype('uint8') + >>> gt_instances = InstanceData() + >>> gt_instances.keypoints = np.array([[[1, 1], [2, 2], [4, 4], + ... [8, 8]]]) + >>> gt_pose_data_sample = PoseDataSample() + >>> gt_pose_data_sample.gt_instances = gt_instances + >>> dataset_meta = {'skeleton_links': [[0, 1], [1, 2], [2, 3]]} + >>> pose_local_visualizer.set_dataset_meta(dataset_meta) + >>> pose_local_visualizer.add_datasample('image', image, + ... gt_pose_data_sample) + >>> pose_local_visualizer.add_datasample( + ... 'image', image, gt_pose_data_sample, + ... out_file='out_file.jpg') + >>> pose_local_visualizer.add_datasample( + ... 'image', image, gt_pose_data_sample, + ... show=True) + >>> pred_instances = InstanceData() + >>> pred_instances.keypoints = np.array([[[1, 1], [2, 2], [4, 4], + ... [8, 8]]]) + >>> pred_instances.score = np.array([0.8, 1, 0.9, 1]) + >>> pred_pose_data_sample = PoseDataSample() + >>> pred_pose_data_sample.pred_instances = pred_instances + >>> pose_local_visualizer.add_datasample('image', image, + ... gt_pose_data_sample, + ... pred_pose_data_sample) + """ + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + bbox_color: Optional[Union[str, Tuple[int]]] = 'green', + kpt_color: Optional[Union[str, Tuple[Tuple[int]]]] = 'red', + link_color: Optional[Union[str, Tuple[Tuple[int]]]] = None, + text_color: Optional[Union[str, + Tuple[int]]] = (255, 255, 255), + skeleton: Optional[Union[List, Tuple]] = None, + line_width: Union[int, float] = 1, + radius: Union[int, float] = 3, + show_keypoint_weight: bool = False, + backend: str = 'opencv', + alpha: float = 0.8): + super().__init__( + name=name, + image=image, + vis_backends=vis_backends, + save_dir=save_dir, + backend=backend) + + self.bbox_color = bbox_color + self.kpt_color = kpt_color + self.link_color = link_color + self.line_width = line_width + self.text_color = text_color + self.skeleton = skeleton + self.radius = radius + self.alpha = alpha + self.show_keypoint_weight = show_keypoint_weight + # Set default value. When calling + # `PoseLocalVisualizer().set_dataset_meta(xxx)`, + # it will override the default value. + self.dataset_meta = {} + + def set_dataset_meta(self, + dataset_meta: Dict, + skeleton_style: str = 'mmpose'): + """Assign dataset_meta to the visualizer. The default visualization + settings will be overridden. + + Args: + dataset_meta (dict): meta information of dataset. + """ + if dataset_meta.get( + 'dataset_name') == 'coco' and skeleton_style == 'openpose': + dataset_meta = parse_pose_metainfo( + dict(from_file='configs/_base_/datasets/coco_openpose.py')) + + if isinstance(dataset_meta, dict): + self.dataset_meta = dataset_meta.copy() + self.bbox_color = dataset_meta.get('bbox_color', self.bbox_color) + self.kpt_color = dataset_meta.get('keypoint_colors', + self.kpt_color) + self.link_color = dataset_meta.get('skeleton_link_colors', + self.link_color) + self.skeleton = dataset_meta.get('skeleton_links', self.skeleton) + # sometimes self.dataset_meta is manually set, which might be None. + # it should be converted to a dict at these times + if self.dataset_meta is None: + self.dataset_meta = {} + + def _draw_instances_bbox(self, image: np.ndarray, + instances: InstanceData) -> np.ndarray: + """Draw bounding boxes and corresponding labels of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + instances (:obj:`InstanceData`): Data structure for + instance-level annotations or predictions. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + self.set_image(image) + + if 'bboxes' in instances: + bboxes = instances.bboxes + self.draw_bboxes( + bboxes, + edge_colors=self.bbox_color, + alpha=self.alpha, + line_widths=self.line_width) + else: + return self.get_image() + + if 'labels' in instances and self.text_color is not None: + classes = self.dataset_meta.get('classes', None) + labels = instances.labels + + positions = bboxes[:, :2] + areas = (bboxes[:, 3] - bboxes[:, 1]) * ( + bboxes[:, 2] - bboxes[:, 0]) + scales = _get_adaptive_scales(areas) + + for i, (pos, label) in enumerate(zip(positions, labels)): + label_text = classes[ + label] if classes is not None else f'class {label}' + + if isinstance(self.bbox_color, + tuple) and max(self.bbox_color) > 1: + facecolor = [c / 255.0 for c in self.bbox_color] + else: + facecolor = self.bbox_color + + self.draw_texts( + label_text, + pos, + colors=self.text_color, + font_sizes=int(13 * scales[i]), + vertical_alignments='bottom', + bboxes=[{ + 'facecolor': facecolor, + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }]) + + return self.get_image() + + def _draw_instances_kpts(self, + image: np.ndarray, + instances: InstanceData, + kpt_thr: float = 0.3, + show_kpt_idx: bool = False, + skeleton_style: str = 'mmpose'): + """Draw keypoints and skeletons (optional) of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + instances (:obj:`InstanceData`): Data structure for + instance-level annotations or predictions. + kpt_thr (float, optional): Minimum threshold of keypoints + to be shown. Default: 0.3. + show_kpt_idx (bool): Whether to show the index of keypoints. + Defaults to ``False`` + skeleton_style (str): Skeleton style selection. Defaults to + ``'mmpose'`` + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + + self.set_image(image) + img_h, img_w, _ = image.shape + + if 'keypoints' in instances: + keypoints = instances.get('transformed_keypoints', + instances.keypoints) + + if 'keypoint_scores' in instances: + scores = instances.keypoint_scores + else: + scores = np.ones(keypoints.shape[:-1]) + + if 'keypoints_visible' in instances: + keypoints_visible = instances.keypoints_visible + else: + keypoints_visible = np.ones(keypoints.shape[:-1]) + + if skeleton_style == 'openpose': + keypoints_info = np.concatenate( + (keypoints, scores[..., None], keypoints_visible[..., + None]), + axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > kpt_thr, + keypoints_info[:, 6, 2:4] > kpt_thr).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores, keypoints_visible = keypoints_info[ + ..., :2], keypoints_info[..., 2], keypoints_info[..., 3] + + for kpts, score, visible in zip(keypoints, scores, + keypoints_visible): + kpts = np.array(kpts, copy=False) + + if self.kpt_color is None or isinstance(self.kpt_color, str): + kpt_color = [self.kpt_color] * len(kpts) + elif len(self.kpt_color) == len(kpts): + kpt_color = self.kpt_color + else: + raise ValueError( + f'the length of kpt_color ' + f'({len(self.kpt_color)}) does not matches ' + f'that of keypoints ({len(kpts)})') + + # draw links + if self.skeleton is not None and self.link_color is not None: + if self.link_color is None or isinstance( + self.link_color, str): + link_color = [self.link_color] * len(self.skeleton) + elif len(self.link_color) == len(self.skeleton): + link_color = self.link_color + else: + raise ValueError( + f'the length of link_color ' + f'({len(self.link_color)}) does not matches ' + f'that of skeleton ({len(self.skeleton)})') + + for sk_id, sk in enumerate(self.skeleton): + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + if not (visible[sk[0]] and visible[sk[1]]): + continue + + if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 + or pos1[1] >= img_h or pos2[0] <= 0 + or pos2[0] >= img_w or pos2[1] <= 0 + or pos2[1] >= img_h or score[sk[0]] < kpt_thr + or score[sk[1]] < kpt_thr + or link_color[sk_id] is None): + # skip the link that should not be drawn + continue + X = np.array((pos1[0], pos2[0])) + Y = np.array((pos1[1], pos2[1])) + color = link_color[sk_id] + if not isinstance(color, str): + color = tuple(int(c) for c in color) + transparency = self.alpha + if self.show_keypoint_weight: + transparency *= max( + 0, min(1, 0.5 * (score[sk[0]] + score[sk[1]]))) + + if skeleton_style == 'openpose': + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 + angle = math.degrees( + math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = 2 + polygons = cv2.ellipse2Poly( + (int(mX), int(mY)), + (int(length / 2), int(stickwidth)), int(angle), + 0, 360, 1) + + self.draw_polygons( + polygons, + edge_colors=color, + face_colors=color, + alpha=transparency) + + else: + self.draw_lines( + X, Y, color, line_widths=self.line_width) + + # draw each point on image + for kid, kpt in enumerate(kpts): + if score[kid] < kpt_thr or not visible[ + kid] or kpt_color[kid] is None: + # skip the point that should not be drawn + continue + + color = kpt_color[kid] + if not isinstance(color, str): + color = tuple(int(c) for c in color) + transparency = self.alpha + if self.show_keypoint_weight: + transparency *= max(0, min(1, score[kid])) + self.draw_circles( + kpt, + radius=np.array([self.radius]), + face_colors=color, + edge_colors=color, + alpha=transparency, + line_widths=self.radius) + if show_kpt_idx: + kpt[0] += self.radius + kpt[1] -= self.radius + self.draw_texts( + str(kid), + kpt, + colors=color, + font_sizes=self.radius * 3, + vertical_alignments='bottom', + horizontal_alignments='center') + + return self.get_image() + + def _draw_instance_heatmap( + self, + fields: PixelData, + overlaid_image: Optional[np.ndarray] = None, + ): + """Draw heatmaps of GT or prediction. + + Args: + fields (:obj:`PixelData`): Data structure for + pixel-level annotations or predictions. + overlaid_image (np.ndarray): The image to draw. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + if 'heatmaps' not in fields: + return None + heatmaps = fields.heatmaps + if isinstance(heatmaps, np.ndarray): + heatmaps = torch.from_numpy(heatmaps) + if heatmaps.dim() == 3: + heatmaps, _ = heatmaps.max(dim=0) + heatmaps = heatmaps.unsqueeze(0) + out_image = self.draw_featmap(heatmaps, overlaid_image) + return out_image + + def _draw_instance_xy_heatmap( + self, + fields: PixelData, + overlaid_image: Optional[np.ndarray] = None, + n: int = 20, + ): + """Draw heatmaps of GT or prediction. + + Args: + fields (:obj:`PixelData`): Data structure for + pixel-level annotations or predictions. + overlaid_image (np.ndarray): The image to draw. + n (int): Number of keypoint, up to 20. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + if 'heatmaps' not in fields: + return None + heatmaps = fields.heatmaps + _, h, w = heatmaps.shape + if isinstance(heatmaps, np.ndarray): + heatmaps = torch.from_numpy(heatmaps) + out_image = SimCCVisualizer().draw_instance_xy_heatmap( + heatmaps, overlaid_image, n) + out_image = cv2.resize(out_image[:, :, ::-1], (w, h)) + return out_image + + @master_only + def add_datasample(self, + name: str, + image: np.ndarray, + data_sample: PoseDataSample, + draw_gt: bool = True, + draw_pred: bool = True, + draw_heatmap: bool = False, + draw_bbox: bool = False, + show_kpt_idx: bool = False, + skeleton_style: str = 'mmpose', + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + kpt_thr: float = 0.3, + step: int = 0) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. t is usually used when the display + is not available. + + Args: + name (str): The image identifier + image (np.ndarray): The image to draw + data_sample (:obj:`PoseDataSample`, optional): The data sample + to visualize + draw_gt (bool): Whether to draw GT PoseDataSample. Default to + ``True`` + draw_pred (bool): Whether to draw Prediction PoseDataSample. + Defaults to ``True`` + draw_bbox (bool): Whether to draw bounding boxes. Default to + ``False`` + draw_heatmap (bool): Whether to draw heatmaps. Defaults to + ``False`` + show_kpt_idx (bool): Whether to show the index of keypoints. + Defaults to ``False`` + skeleton_style (str): Skeleton style selection. Defaults to + ``'mmpose'`` + show (bool): Whether to display the drawn image. Default to + ``False`` + wait_time (float): The interval of show (s). Defaults to 0 + out_file (str): Path to output file. Defaults to ``None`` + kpt_thr (float, optional): Minimum threshold of keypoints + to be shown. Default: 0.3. + step (int): Global step value to record. Defaults to 0 + """ + + gt_img_data = None + pred_img_data = None + + if draw_gt: + gt_img_data = image.copy() + gt_img_heatmap = None + + # draw bboxes & keypoints + if 'gt_instances' in data_sample: + gt_img_data = self._draw_instances_kpts( + gt_img_data, data_sample.gt_instances, kpt_thr, + show_kpt_idx, skeleton_style) + if draw_bbox: + gt_img_data = self._draw_instances_bbox( + gt_img_data, data_sample.gt_instances) + + # draw heatmaps + if 'gt_fields' in data_sample and draw_heatmap: + gt_img_heatmap = self._draw_instance_heatmap( + data_sample.gt_fields, image) + if gt_img_heatmap is not None: + gt_img_data = np.concatenate((gt_img_data, gt_img_heatmap), + axis=0) + + if draw_pred: + pred_img_data = image.copy() + pred_img_heatmap = None + + # draw bboxes & keypoints + if 'pred_instances' in data_sample: + pred_img_data = self._draw_instances_kpts( + pred_img_data, data_sample.pred_instances, kpt_thr, + show_kpt_idx, skeleton_style) + if draw_bbox: + pred_img_data = self._draw_instances_bbox( + pred_img_data, data_sample.pred_instances) + + # draw heatmaps + if 'pred_fields' in data_sample and draw_heatmap: + if 'keypoint_x_labels' in data_sample.pred_instances: + pred_img_heatmap = self._draw_instance_xy_heatmap( + data_sample.pred_fields, image) + else: + pred_img_heatmap = self._draw_instance_heatmap( + data_sample.pred_fields, image) + if pred_img_heatmap is not None: + pred_img_data = np.concatenate( + (pred_img_data, pred_img_heatmap), axis=0) + + # merge visualization results + if gt_img_data is not None and pred_img_data is not None: + if gt_img_heatmap is None and pred_img_heatmap is not None: + gt_img_data = np.concatenate((gt_img_data, image), axis=0) + elif gt_img_heatmap is not None and pred_img_heatmap is None: + pred_img_data = np.concatenate((pred_img_data, image), axis=0) + + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + # It is convenient for users to obtain the drawn image. + # For example, the user wants to obtain the drawn image and + # save it as a video during video inference. + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + # save drawn_img to backends + self.add_image(name, drawn_img, step) + + return self.get_image() diff --git a/mmpose/visualization/opencv_backend_visualizer.py b/mmpose/visualization/opencv_backend_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..66a7731c76e19a8e7f719a24748b64f316326084 --- /dev/null +++ b/mmpose/visualization/opencv_backend_visualizer.py @@ -0,0 +1,444 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import cv2 +import mmcv +import numpy as np +import torch +from mmengine.dist import master_only +from mmengine.visualization import Visualizer + + +class OpencvBackendVisualizer(Visualizer): + """Base visualizer with opencv backend support. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + fig_save_cfg (dict): Keyword parameters of figure for saving. + Defaults to empty dict. + fig_show_cfg (dict): Keyword parameters of figure for showing. + Defaults to empty dict. + backend (str): Backend used to draw elements on the image and display + the image. Defaults to 'matplotlib'. + """ + + def __init__(self, + name='visualizer', + backend: str = 'matplotlib', + *args, + **kwargs): + super().__init__(name, *args, **kwargs) + assert backend in ('opencv', 'matplotlib'), f'the argument ' \ + f'\'backend\' must be either \'opencv\' or \'matplotlib\', ' \ + f'but got \'{backend}\'.' + self.backend = backend + + @master_only + def set_image(self, image: np.ndarray) -> None: + """Set the image to draw. + + Args: + image (np.ndarray): The image to draw. + backend (str): The backend to save the image. + """ + assert image is not None + image = image.astype('uint8') + self._image = image + self.width, self.height = image.shape[1], image.shape[0] + self._default_font_size = max( + np.sqrt(self.height * self.width) // 90, 10) + + if self.backend == 'matplotlib': + # add a small 1e-2 to avoid precision lost due to matplotlib's + # truncation (https://github.com/matplotlib/matplotlib/issues/15363) # noqa + self.fig_save.set_size_inches( # type: ignore + (self.width + 1e-2) / self.dpi, + (self.height + 1e-2) / self.dpi) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + self.ax_save.cla() + self.ax_save.axis(False) + self.ax_save.imshow( + image, + extent=(0, self.width, self.height, 0), + interpolation='none') + + @master_only + def get_image(self) -> np.ndarray: + """Get the drawn image. The format is RGB. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + assert self._image is not None, 'Please set image using `set_image`' + if self.backend == 'matplotlib': + return super().get_image() + else: + return self._image + + @master_only + def draw_circles(self, + center: Union[np.ndarray, torch.Tensor], + radius: Union[np.ndarray, torch.Tensor], + face_colors: Union[str, tuple, List[str], + List[tuple]] = 'none', + **kwargs) -> 'Visualizer': + """Draw single or multiple circles. + + Args: + center (Union[np.ndarray, torch.Tensor]): The x coordinate of + each line' start and end points. + radius (Union[np.ndarray, torch.Tensor]): The y coordinate of + each line' start and end points. + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of circles. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, + all the lines will have the same colors. Reference to + https://matplotlib.org/stable/gallery/color/named_colors.html + for more details. Defaults to 'g. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to '-'. + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Defaults to None. + alpha (Union[int, float]): The transparency of circles. + Defaults to 0.8. + """ + if self.backend == 'matplotlib': + super().draw_circles( + center=center, + radius=radius, + face_colors=face_colors, + **kwargs) + elif self.backend == 'opencv': + if isinstance(face_colors, str): + face_colors = mmcv.color_val(face_colors) + self._image = cv2.circle(self._image, + (int(center[0]), int(center[1])), + int(radius), face_colors, -1) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def draw_texts( + self, + texts: Union[str, List[str]], + positions: Union[np.ndarray, torch.Tensor], + font_sizes: Optional[Union[int, List[int]]] = None, + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + vertical_alignments: Union[str, List[str]] = 'top', + horizontal_alignments: Union[str, List[str]] = 'left', + bboxes: Optional[Union[dict, List[dict]]] = None, + **kwargs, + ) -> 'Visualizer': + """Draw single or multiple text boxes. + + Args: + texts (Union[str, List[str]]): Texts to draw. + positions (Union[np.ndarray, torch.Tensor]): The position to draw + the texts, which should have the same length with texts and + each dim contain x and y. + font_sizes (Union[int, List[int]], optional): The font size of + texts. ``font_sizes`` can have the same length with texts or + just single value. If ``font_sizes`` is single value, all the + texts will have the same font size. Defaults to None. + colors (Union[str, tuple, List[str], List[tuple]]): The colors + of texts. ``colors`` can have the same length with texts or + just single value. If ``colors`` is single value, all the + texts will have the same colors. Reference to + https://matplotlib.org/stable/gallery/color/named_colors.html + for more details. Defaults to 'g. + vertical_alignments (Union[str, List[str]]): The verticalalignment + of texts. verticalalignment controls whether the y positional + argument for the text indicates the bottom, center or top side + of the text bounding box. + ``vertical_alignments`` can have the same length with + texts or just single value. If ``vertical_alignments`` is + single value, all the texts will have the same + verticalalignment. verticalalignment can be 'center' or + 'top', 'bottom' or 'baseline'. Defaults to 'top'. + horizontal_alignments (Union[str, List[str]]): The + horizontalalignment of texts. Horizontalalignment controls + whether the x positional argument for the text indicates the + left, center or right side of the text bounding box. + ``horizontal_alignments`` can have + the same length with texts or just single value. + If ``horizontal_alignments`` is single value, all the texts + will have the same horizontalalignment. Horizontalalignment + can be 'center','right' or 'left'. Defaults to 'left'. + font_families (Union[str, List[str]]): The font family of + texts. ``font_families`` can have the same length with texts or + just single value. If ``font_families`` is single value, all + the texts will have the same font family. + font_familiy can be 'serif', 'sans-serif', 'cursive', 'fantasy' + or 'monospace'. Defaults to 'sans-serif'. + bboxes (Union[dict, List[dict]], optional): The bounding box of the + texts. If bboxes is None, there are no bounding box around + texts. ``bboxes`` can have the same length with texts or + just single value. If ``bboxes`` is single value, all + the texts will have the same bbox. Reference to + https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.FancyBboxPatch.html#matplotlib.patches.FancyBboxPatch + for more details. Defaults to None. + font_properties (Union[FontProperties, List[FontProperties]], optional): + The font properties of texts. FontProperties is + a ``font_manager.FontProperties()`` object. + If you want to draw Chinese texts, you need to prepare + a font file that can show Chinese characters properly. + For example: `simhei.ttf`, `simsun.ttc`, `simkai.ttf` and so on. + Then set ``font_properties=matplotlib.font_manager.FontProperties(fname='path/to/font_file')`` + ``font_properties`` can have the same length with texts or + just single value. If ``font_properties`` is single value, + all the texts will have the same font properties. + Defaults to None. + `New in version 0.6.0.` + """ # noqa: E501 + + if self.backend == 'matplotlib': + super().draw_texts( + texts=texts, + positions=positions, + font_sizes=font_sizes, + colors=colors, + vertical_alignments=vertical_alignments, + horizontal_alignments=horizontal_alignments, + bboxes=bboxes, + **kwargs) + + elif self.backend == 'opencv': + font_scale = max(0.1, font_sizes / 30) + thickness = max(1, font_sizes // 15) + + text_size, text_baseline = cv2.getTextSize(texts, + cv2.FONT_HERSHEY_DUPLEX, + font_scale, thickness) + + x = int(positions[0]) + if horizontal_alignments == 'right': + x = max(0, x - text_size[0]) + y = int(positions[1]) + if vertical_alignments == 'top': + y = min(self.height, y + text_size[1]) + + if bboxes is not None: + bbox_color = bboxes[0]['facecolor'] + if isinstance(bbox_color, str): + bbox_color = mmcv.color_val(bbox_color) + + y = y - text_baseline // 2 + self._image = cv2.rectangle( + self._image, (x, y - text_size[1] - text_baseline // 2), + (x + text_size[0], y + text_baseline // 2), bbox_color, + cv2.FILLED) + + self._image = cv2.putText(self._image, texts, (x, y), + cv2.FONT_HERSHEY_SIMPLEX, font_scale, + colors, thickness - 1) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def draw_bboxes(self, + bboxes: Union[np.ndarray, torch.Tensor], + edge_colors: Union[str, tuple, List[str], + List[tuple]] = 'g', + line_widths: Union[Union[int, float], + List[Union[int, float]]] = 2, + **kwargs) -> 'Visualizer': + """Draw single or multiple bboxes. + + Args: + bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with + the format of(x1,y1,x2,y2). + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of bboxes. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, all + the lines will have the same colors. Refer to `matplotlib. + colors` for full list of formats that are accepted. + Defaults to 'g'. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to '-'. + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Defaults to None. + alpha (Union[int, float]): The transparency of bboxes. + Defaults to 0.8. + """ + if self.backend == 'matplotlib': + super().draw_bboxes( + bboxes=bboxes, + edge_colors=edge_colors, + line_widths=line_widths, + **kwargs) + + elif self.backend == 'opencv': + self._image = mmcv.imshow_bboxes( + self._image, + bboxes, + edge_colors, + top_k=-1, + thickness=line_widths, + show=False) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def draw_lines(self, + x_datas: Union[np.ndarray, torch.Tensor], + y_datas: Union[np.ndarray, torch.Tensor], + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_widths: Union[Union[int, float], + List[Union[int, float]]] = 2, + **kwargs) -> 'Visualizer': + """Draw single or multiple line segments. + + Args: + x_datas (Union[np.ndarray, torch.Tensor]): The x coordinate of + each line' start and end points. + y_datas (Union[np.ndarray, torch.Tensor]): The y coordinate of + each line' start and end points. + colors (Union[str, tuple, List[str], List[tuple]]): The colors of + lines. ``colors`` can have the same length with lines or just + single value. If ``colors`` is single value, all the lines + will have the same colors. Reference to + https://matplotlib.org/stable/gallery/color/named_colors.html + for more details. Defaults to 'g'. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to '-'. + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. + """ + if self.backend == 'matplotlib': + super().draw_lines( + x_datas=x_datas, + y_datas=y_datas, + colors=colors, + line_widths=line_widths, + **kwargs) + + elif self.backend == 'opencv': + + self._image = cv2.line( + self._image, (x_datas[0], y_datas[0]), + (x_datas[1], y_datas[1]), + colors, + thickness=line_widths) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def draw_polygons(self, + polygons: Union[Union[np.ndarray, torch.Tensor], + List[Union[np.ndarray, torch.Tensor]]], + edge_colors: Union[str, tuple, List[str], + List[tuple]] = 'g', + **kwargs) -> 'Visualizer': + """Draw single or multiple bboxes. + + Args: + polygons (Union[Union[np.ndarray, torch.Tensor],\ + List[Union[np.ndarray, torch.Tensor]]]): The polygons to draw + with the format of (x1,y1,x2,y2,...,xn,yn). + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of polygons. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, + all the lines will have the same colors. Refer to + `matplotlib.colors` for full list of formats that are accepted. + Defaults to 'g. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to '-'. + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Defaults to None. + alpha (Union[int, float]): The transparency of polygons. + Defaults to 0.8. + """ + if self.backend == 'matplotlib': + super().draw_polygons( + polygons=polygons, edge_colors=edge_colors, **kwargs) + + elif self.backend == 'opencv': + + self._image = cv2.fillConvexPoly(self._image, polygons, + edge_colors) + else: + raise ValueError(f'got unsupported backend {self.backend}') + + @master_only + def show(self, + drawn_img: Optional[np.ndarray] = None, + win_name: str = 'image', + wait_time: float = 0., + continue_key=' ') -> None: + """Show the drawn image. + + Args: + drawn_img (np.ndarray, optional): The image to show. If drawn_img + is None, it will show the image got by Visualizer. Defaults + to None. + win_name (str): The image title. Defaults to 'image'. + wait_time (float): Delay in seconds. 0 is the special + value that means "forever". Defaults to 0. + continue_key (str): The key for users to continue. Defaults to + the space key. + """ + if self.backend == 'matplotlib': + super().show( + drawn_img=drawn_img, + win_name=win_name, + wait_time=wait_time, + continue_key=continue_key) + + elif self.backend == 'opencv': + # Keep images are shown in the same window, and the title of window + # will be updated with `win_name`. + if not hasattr(self, win_name): + self._cv_win_name = win_name + cv2.namedWindow(winname=f'{id(self)}') + cv2.setWindowTitle(f'{id(self)}', win_name) + else: + cv2.setWindowTitle(f'{id(self)}', win_name) + shown_img = self.get_image() if drawn_img is None else drawn_img + cv2.imshow(str(id(self)), mmcv.bgr2rgb(shown_img)) + cv2.waitKey(int(np.ceil(wait_time * 1000))) + else: + raise ValueError(f'got unsupported backend {self.backend}') diff --git a/mmpose/visualization/simcc_vis.py b/mmpose/visualization/simcc_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5b602fb5c4ffe2a46ddb2cf09a2cd4501b1664 --- /dev/null +++ b/mmpose/visualization/simcc_vis.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import cv2 as cv +import numpy as np +import torch +from torchvision.transforms import ToPILImage + + +class SimCCVisualizer: + + def draw_instance_xy_heatmap(self, + heatmap: torch.Tensor, + overlaid_image: Optional[np.ndarray], + n: int = 20, + mix: bool = True, + weight: float = 0.5): + """Draw heatmaps of GT or prediction. + + Args: + heatmap (torch.Tensor): Tensor of heatmap. + overlaid_image (np.ndarray): The image to draw. + n (int): Number of keypoint, up to 20. + mix (bool):Whether to merge heatmap and original image. + weight (float): Weight of original image during fusion. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + heatmap2d = heatmap.data.max(0, keepdim=True)[0] + xy_heatmap, K = self.split_simcc_xy(heatmap) + K = K if K <= n else n + blank_size = tuple(heatmap.size()[1:]) + maps = {'x': [], 'y': []} + for i in xy_heatmap: + x, y = self.draw_1d_heatmaps(i['x']), self.draw_1d_heatmaps(i['y']) + maps['x'].append(x) + maps['y'].append(y) + white = self.creat_blank(blank_size, K) + map2d = self.draw_2d_heatmaps(heatmap2d) + if mix: + map2d = cv.addWeighted(overlaid_image, 1 - weight, map2d, weight, + 0) + self.image_cover(white, map2d, int(blank_size[1] * 0.1), + int(blank_size[0] * 0.1)) + white = self.add_1d_heatmaps(maps, white, blank_size, K) + return white + + def split_simcc_xy(self, heatmap: Union[np.ndarray, torch.Tensor]): + """Extract one-dimensional heatmap from two-dimensional heatmap and + calculate the number of keypoint.""" + size = heatmap.size() + k = size[0] if size[0] <= 20 else 20 + maps = [] + for _ in range(k): + xy_dict = {} + single_heatmap = heatmap[_] + xy_dict['x'], xy_dict['y'] = self.merge_maps(single_heatmap) + maps.append(xy_dict) + return maps, k + + def merge_maps(self, map_2d): + """Synthesis of one-dimensional heatmap.""" + x = map_2d.data.max(0, keepdim=True)[0] + y = map_2d.data.max(1, keepdim=True)[0] + return x, y + + def draw_1d_heatmaps(self, heatmap_1d): + """Draw one-dimensional heatmap.""" + size = heatmap_1d.size() + length = max(size) + np_heatmap = ToPILImage()(heatmap_1d).convert('RGB') + cv_img = cv.cvtColor(np.asarray(np_heatmap), cv.COLOR_RGB2BGR) + if size[0] < size[1]: + cv_img = cv.resize(cv_img, (length, 15)) + else: + cv_img = cv.resize(cv_img, (15, length)) + single_map = cv.applyColorMap(cv_img, cv.COLORMAP_JET) + return single_map + + def creat_blank(self, + size: Union[list, tuple], + K: int = 20, + interval: int = 10): + """Create the background.""" + blank_height = int( + max(size[0] * 2, size[0] * 1.1 + (K + 1) * (15 + interval))) + blank_width = int( + max(size[1] * 2, size[1] * 1.1 + (K + 1) * (15 + interval))) + blank = np.zeros((blank_height, blank_width, 3), np.uint8) + blank.fill(255) + return blank + + def draw_2d_heatmaps(self, heatmap_2d): + """Draw a two-dimensional heatmap fused with the original image.""" + np_heatmap = ToPILImage()(heatmap_2d).convert('RGB') + cv_img = cv.cvtColor(np.asarray(np_heatmap), cv.COLOR_RGB2BGR) + map_2d = cv.applyColorMap(cv_img, cv.COLORMAP_JET) + return map_2d + + def image_cover(self, background: np.ndarray, foreground: np.ndarray, + x: int, y: int): + """Paste the foreground on the background.""" + fore_size = foreground.shape + background[y:y + fore_size[0], x:x + fore_size[1]] = foreground + return background + + def add_1d_heatmaps(self, + maps: dict, + background: np.ndarray, + map2d_size: Union[tuple, list], + K: int, + interval: int = 10): + """Paste one-dimensional heatmaps onto the background in turn.""" + y_startpoint, x_startpoint = [int(1.1*map2d_size[1]), + int(0.1*map2d_size[0])],\ + [int(0.1*map2d_size[1]), + int(1.1*map2d_size[0])] + x_startpoint[1] += interval * 2 + y_startpoint[0] += interval * 2 + add = interval + 10 + for i in range(K): + self.image_cover(background, maps['x'][i], x_startpoint[0], + x_startpoint[1]) + cv.putText(background, str(i), + (x_startpoint[0] - 30, x_startpoint[1] + 10), + cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2) + self.image_cover(background, maps['y'][i], y_startpoint[0], + y_startpoint[1]) + cv.putText(background, str(i), + (y_startpoint[0], y_startpoint[1] - 5), + cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2) + x_startpoint[1] += add + y_startpoint[0] += add + return background[:x_startpoint[1] + y_startpoint[1] + + 1, :y_startpoint[0] + x_startpoint[0] + 1] diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d065db1cd80c874b9fa93b84752575dd33e457b --- /dev/null +++ b/mmpretrain/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import mmengine +from mmengine.utils import digit_version + +from .apis import * # noqa: F401, F403 +from .version import __version__ + +mmcv_minimum_version = '2.0.0' +mmcv_maximum_version = '2.1.0' +mmcv_version = digit_version(mmcv.__version__) + +mmengine_minimum_version = '0.8.0' +mmengine_maximum_version = '1.0.0' +mmengine_version = digit_version(mmengine.__version__) + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version < digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.' + +assert (mmengine_version >= digit_version(mmengine_minimum_version) + and mmengine_version < digit_version(mmengine_maximum_version)), \ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_minimum_version}, ' \ + f'<{mmengine_maximum_version}.' + +__all__ = ['__version__'] diff --git a/mmpretrain/__pycache__/__init__.cpython-38.pyc b/mmpretrain/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b860b5bc183e15e5c451735306cb3d0e555b03 Binary files /dev/null and b/mmpretrain/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/__pycache__/registry.cpython-38.pyc b/mmpretrain/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d0a1f36601180e535caba8842463df51bddd032 Binary files /dev/null and b/mmpretrain/__pycache__/registry.cpython-38.pyc differ diff --git a/mmpretrain/__pycache__/version.cpython-38.pyc b/mmpretrain/__pycache__/version.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..129ab5cbf1c760ae3dd5425ada9c2f393f432b2d Binary files /dev/null and b/mmpretrain/__pycache__/version.cpython-38.pyc differ diff --git a/mmpretrain/apis/__init__.py b/mmpretrain/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbf443772a983c41f7273124f843bdfbb7f0f46 --- /dev/null +++ b/mmpretrain/apis/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseInferencer +from .feature_extractor import FeatureExtractor +from .image_caption import ImageCaptionInferencer +from .image_classification import ImageClassificationInferencer +from .image_retrieval import ImageRetrievalInferencer +from .model import (ModelHub, get_model, inference_model, init_model, + list_models) +from .multimodal_retrieval import (ImageToTextRetrievalInferencer, + TextToImageRetrievalInferencer) +from .nlvr import NLVRInferencer +from .visual_grounding import VisualGroundingInferencer +from .visual_question_answering import VisualQuestionAnsweringInferencer + +__all__ = [ + 'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub', + 'ImageClassificationInferencer', 'ImageRetrievalInferencer', + 'FeatureExtractor', 'ImageCaptionInferencer', + 'TextToImageRetrievalInferencer', 'VisualGroundingInferencer', + 'VisualQuestionAnsweringInferencer', 'ImageToTextRetrievalInferencer', + 'BaseInferencer', 'NLVRInferencer' +] diff --git a/mmpretrain/apis/__pycache__/__init__.cpython-38.pyc b/mmpretrain/apis/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65bf34ee4d103bc7bda7306c6de0eeb3ff9e8ef0 Binary files /dev/null and b/mmpretrain/apis/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/base.cpython-38.pyc b/mmpretrain/apis/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbbf2a9206b4c9cd8b7d92e211afe41f2eccae7d Binary files /dev/null and b/mmpretrain/apis/__pycache__/base.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/feature_extractor.cpython-38.pyc b/mmpretrain/apis/__pycache__/feature_extractor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71ef9eb92b29c22d16482e623d7df6722545a017 Binary files /dev/null and b/mmpretrain/apis/__pycache__/feature_extractor.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/image_caption.cpython-38.pyc b/mmpretrain/apis/__pycache__/image_caption.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97405f793331b9d57c219ab8be4da944ddc92418 Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_caption.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/image_classification.cpython-38.pyc b/mmpretrain/apis/__pycache__/image_classification.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c15984ff844b44847a60346c962787c07c682b6a Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_classification.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/image_retrieval.cpython-38.pyc b/mmpretrain/apis/__pycache__/image_retrieval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fefb69909a2d7b18db87b3fd24fcc0cceac79c87 Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_retrieval.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/model.cpython-38.pyc b/mmpretrain/apis/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ab396d53f85702d40e68614323c576847384579 Binary files /dev/null and b/mmpretrain/apis/__pycache__/model.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-38.pyc b/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d02410fb5c709d99c598a172b4ef824f5aea91ea Binary files /dev/null and b/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/nlvr.cpython-38.pyc b/mmpretrain/apis/__pycache__/nlvr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68cd7897dd000482a0a20c3360d8313132873732 Binary files /dev/null and b/mmpretrain/apis/__pycache__/nlvr.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/visual_grounding.cpython-38.pyc b/mmpretrain/apis/__pycache__/visual_grounding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56aa69cb83e444180b35cc43d1bc2e741c3a963d Binary files /dev/null and b/mmpretrain/apis/__pycache__/visual_grounding.cpython-38.pyc differ diff --git a/mmpretrain/apis/__pycache__/visual_question_answering.cpython-38.pyc b/mmpretrain/apis/__pycache__/visual_question_answering.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8befd7f2e8544ae5695c64be347b5b3e67fdbee4 Binary files /dev/null and b/mmpretrain/apis/__pycache__/visual_question_answering.cpython-38.pyc differ diff --git a/mmpretrain/apis/base.py b/mmpretrain/apis/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7bff6bd18675a3a0996dcd09081a15728311657f --- /dev/null +++ b/mmpretrain/apis/base.py @@ -0,0 +1,390 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from math import ceil +from typing import Callable, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.config import Config +from mmengine.dataset import default_collate +from mmengine.fileio import get_file_backend +from mmengine.model import BaseModel +from mmengine.runner import load_checkpoint + +from mmpretrain.structures import DataSample +from mmpretrain.utils import track +from .model import get_model, list_models + +ModelType = Union[BaseModel, str, Config] +InputType = Union[str, np.ndarray, list] + + +class BaseInferencer: + """Base inferencer for various tasks. + + The BaseInferencer provides the standard workflow for inference as follows: + + 1. Preprocess the input data by :meth:`preprocess`. + 2. Forward the data to the model by :meth:`forward`. ``BaseInferencer`` + assumes the model inherits from :class:`mmengine.models.BaseModel` and + will call `model.test_step` in :meth:`forward` by default. + 3. Visualize the results by :meth:`visualize`. + 4. Postprocess and return the results by :meth:`postprocess`. + + When we call the subclasses inherited from BaseInferencer (not overriding + ``__call__``), the workflow will be executed in order. + + All subclasses of BaseInferencer could define the following class + attributes for customization: + + - ``preprocess_kwargs``: The keys of the kwargs that will be passed to + :meth:`preprocess`. + - ``forward_kwargs``: The keys of the kwargs that will be passed to + :meth:`forward` + - ``visualize_kwargs``: The keys of the kwargs that will be passed to + :meth:`visualize` + - ``postprocess_kwargs``: The keys of the kwargs that will be passed to + :meth:`postprocess` + + All attributes mentioned above should be a ``set`` of keys (strings), + and each key should not be duplicated. Actually, :meth:`__call__` will + dispatch all the arguments to the corresponding methods according to the + ``xxx_kwargs`` mentioned above. + + Subclasses inherited from ``BaseInferencer`` should implement + :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`: + + - _init_pipeline: Return a callable object to preprocess the input data. + - visualize: Visualize the results returned by :meth:`forward`. + - postprocess: Postprocess the results returned by :meth:`forward` and + :meth:`visualize`. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``cls.list_models()`` and you can also query it in + :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str | torch.device | None): Transfer the model to the target + device. Defaults to None. + device_map (str | dict | None): A map that specifies where each + submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every + submodule of it will be sent to the same device. You can use + `device_map="auto"` to automatically generate the device map. + Defaults to None. + offload_folder (str | None): If the `device_map` contains any value + `"disk"`, the folder where we will offload weights. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + """ + + preprocess_kwargs: set = set() + forward_kwargs: set = set() + visualize_kwargs: set = set() + postprocess_kwargs: set = set() + + def __init__(self, + model: ModelType, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + device_map=None, + offload_folder=None, + **kwargs) -> None: + + if isinstance(model, BaseModel): + if isinstance(pretrained, str): + load_checkpoint(model, pretrained, map_location='cpu') + if device_map is not None: + from .utils import dispatch_model + model = dispatch_model( + model, + device_map=device_map, + offload_folder=offload_folder) + elif device is not None: + model.to(device) + else: + model = get_model( + model, + pretrained, + device=device, + device_map=device_map, + offload_folder=offload_folder, + **kwargs) + + model.eval() + + self.config = model._config + self.model = model + self.pipeline = self._init_pipeline(self.config) + self.visualizer = None + + def __call__( + self, + inputs, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs, + ) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. + return_datasamples (bool): Whether to return results as + :obj:`BaseDataElement`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + **kwargs: Key words arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. + """ + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + ori_inputs = self._inputs_to_list(inputs) + inputs = self.preprocess( + ori_inputs, batch_size=batch_size, **preprocess_kwargs) + preds = [] + for data in track( + inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)): + preds.extend(self.forward(data, **forward_kwargs)) + visualization = self.visualize(ori_inputs, preds, **visualize_kwargs) + results = self.postprocess(preds, visualization, return_datasamples, + **postprocess_kwargs) + return results + + def _inputs_to_list(self, inputs: InputType) -> list: + """Preprocess the inputs to a list. + + Cast the input data to a list of data. + + - list or tuple: return inputs + - str: + - Directory path: return all files in the directory + - other cases: return a list containing the string. The string + could be a path to file, a url or other types of string according + to the task. + - other: return a list with one item. + + Args: + inputs (str | array | list): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. + """ + if isinstance(inputs, str): + backend = get_file_backend(inputs) + if hasattr(backend, 'isdir') and backend.isdir(inputs): + # Backends like HttpsBackend do not implement `isdir`, so only + # those backends that implement `isdir` could accept the inputs + # as a directory + file_list = backend.list_dir_or_file(inputs, list_dir=False) + inputs = [ + backend.join_path(inputs, file) for file in file_list + ] + + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + return list(inputs) + + def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs): + """Process the inputs into a model-feedable format. + + Customize your preprocess by overriding this method. Preprocess should + return an iterable object, of which each item will be used as the + input of ``model.test_step``. + + ``BaseInferencer.preprocess`` will return an iterable chunked data, + which will be used in __call__ like this: + + .. code-block:: python + + def __call__(self, inputs, batch_size=1, **kwargs): + chunked_data = self.preprocess(inputs, batch_size, **kwargs) + for batch in chunked_data: + preds = self.forward(batch, **kwargs) + + Args: + inputs (InputsType): Inputs given by user. + batch_size (int): batch size. Defaults to 1. + + Yields: + Any: Data processed by the ``pipeline`` and ``default_collate``. + """ + chunked_data = self._get_chunk_data( + map(self.pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + @torch.no_grad() + def forward(self, inputs: Union[dict, tuple], **kwargs): + """Feed the inputs to the model.""" + return self.model.test_step(inputs) + + def visualize(self, + inputs: list, + preds: List[DataSample], + show: bool = False, + **kwargs) -> List[np.ndarray]: + """Visualize predictions. + + Customize your visualization by overriding this method. visualize + should return visualization results, which could be np.ndarray or any + other objects. + + Args: + inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. + preds (Any): Predictions of the model. + show (bool): Whether to display the image in a popup window. + Defaults to False. + + Returns: + List[np.ndarray]: Visualization results. + """ + if show: + raise NotImplementedError( + f'The `visualize` method of {self.__class__.__name__} ' + 'is not implemented.') + + @abstractmethod + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasample=False, + **kwargs, + ) -> dict: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Convert datasamples into a json-serializable dict if needed. + 2. Pack the predictions and visualization results and return them. + 3. Dump or log the predictions. + + Customize your postprocess by overriding this method. Make sure + ``postprocess`` will return a dict with visualization results and + inference results. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (np.ndarray): Visualized predictions. + return_datasample (bool): Whether to return results as datasamples. + Defaults to False. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization`` + + - ``visualization (Any)``: Returned by :meth:`visualize` + - ``predictions`` (dict or DataSample): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it usually should be a + json-serializable dict containing only basic data elements such + as strings and numbers. + """ + + @abstractmethod + def _init_pipeline(self, cfg: Config) -> Callable: + """Initialize the test pipeline. + + Return a pipeline to handle various input data, such as ``str``, + ``np.ndarray``. It is an abstract method in BaseInferencer, and should + be implemented in subclasses. + + The returned pipeline will be used to process a single data. + It will be used in :meth:`preprocess` like this: + + .. code-block:: python + def preprocess(self, inputs, batch_size, **kwargs): + ... + dataset = map(self.pipeline, dataset) + ... + """ + + def _get_chunk_data(self, inputs: Iterable, chunk_size: int): + """Get batch data from dataset. + + Args: + inputs (Iterable): An iterable dataset. + chunk_size (int): Equivalent to batch size. + + Yields: + list: batch data. + """ + inputs_iter = iter(inputs) + while True: + try: + chunk_data = [] + for _ in range(chunk_size): + processed_data = next(inputs_iter) + chunk_data.append(processed_data) + yield chunk_data + except StopIteration: + if chunk_data: + yield chunk_data + break + + def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]: + """Dispatch kwargs to preprocess(), forward(), visualize() and + postprocess() according to the actual demands. + + Returns: + Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess, + forward, visualize and postprocess respectively. + """ + # Ensure each argument only matches one function + method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \ + self.visualize_kwargs | self.postprocess_kwargs + + union_kwargs = method_kwargs | set(kwargs.keys()) + if union_kwargs != method_kwargs: + unknown_kwargs = union_kwargs - method_kwargs + raise ValueError( + f'unknown argument {unknown_kwargs} for `preprocess`, ' + '`forward`, `visualize` and `postprocess`') + + preprocess_kwargs = {} + forward_kwargs = {} + visualize_kwargs = {} + postprocess_kwargs = {} + + for key, value in kwargs.items(): + if key in self.preprocess_kwargs: + preprocess_kwargs[key] = value + if key in self.forward_kwargs: + forward_kwargs[key] = value + if key in self.visualize_kwargs: + visualize_kwargs[key] = value + if key in self.postprocess_kwargs: + postprocess_kwargs[key] = value + + return ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List models defined in metafile of corresponding packages. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern) diff --git a/mmpretrain/apis/feature_extractor.py b/mmpretrain/apis/feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..ee14f92f489497dd036fe0567786a94207924d4a --- /dev/null +++ b/mmpretrain/apis/feature_extractor.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, List, Optional, Union + +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from .base import BaseInferencer, InputType +from .model import list_models + + +class FeatureExtractor(BaseInferencer): + """The inferencer for extract features. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``FeatureExtractor.list_models()`` and you can also query it in + :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import FeatureExtractor + >>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3))) + >>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0] + >>> for feat in feats: + >>> print(feat.shape) + torch.Size([256, 56, 56]) + torch.Size([512, 28, 28]) + torch.Size([1024, 14, 14]) + torch.Size([2048, 7, 7]) + """ # noqa: E501 + + def __call__(self, + inputs: InputType, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + batch_size (int): Batch size. Defaults to 1. + **kwargs: Other keyword arguments accepted by the `extract_feat` + method of the model. + + Returns: + tensor | Tuple[tensor]: The extracted features. + """ + ori_inputs = self._inputs_to_list(inputs) + inputs = self.preprocess(ori_inputs, batch_size=batch_size) + preds = [] + for data in inputs: + preds.extend(self.forward(data, **kwargs)) + + return preds + + @torch.no_grad() + def forward(self, inputs: Union[dict, tuple], **kwargs): + inputs = self.model.data_preprocessor(inputs, False)['inputs'] + outputs = self.model.extract_feat(inputs, **kwargs) + + def scatter(feats, index): + if isinstance(feats, torch.Tensor): + return feats[index] + else: + # Sequence of tensor + return type(feats)([scatter(item, index) for item in feats]) + + results = [] + for i in range(inputs.shape[0]): + results.append(scatter(outputs, i)) + + return results + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self): + raise NotImplementedError( + "The FeatureExtractor doesn't support visualization.") + + def postprocess(self): + raise NotImplementedError( + "The FeatureExtractor doesn't need postprocessing.") + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern) diff --git a/mmpretrain/apis/image_caption.py b/mmpretrain/apis/image_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..c11c0d3044d9924aba159782309d2cc20f1745bc --- /dev/null +++ b/mmpretrain/apis/image_caption.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional + +import numpy as np +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer, InputType +from .model import list_models + + +class ImageCaptionInferencer(BaseInferencer): + """The inferencer for image caption. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageCaptionInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import ImageCaptionInferencer + >>> inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption') + >>> inferencer('demo/cat-dog.png')[0] + {'pred_caption': 'a puppy and a cat sitting on a blanket'} + """ # noqa: E501 + + visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'} + + def __call__(self, + images: InputType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + images (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(images, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_image_caption( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({'pred_caption': data_sample.get('pred_caption')}) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Caption') diff --git a/mmpretrain/apis/image_classification.py b/mmpretrain/apis/image_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..a20218071c7afc90c6a46d61b5ed3a8fee5bc012 --- /dev/null +++ b/mmpretrain/apis/image_classification.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer, InputType, ModelType +from .model import list_models + + +class ImageClassificationInferencer(BaseInferencer): + """The inferencer for image classification. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageClassificationInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + 1. Use a pre-trained model in MMPreTrain to inference an image. + + >>> from mmpretrain import ImageClassificationInferencer + >>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k') + >>> inferencer('demo/demo.JPEG') + [{'pred_score': array([...]), + 'pred_label': 65, + 'pred_score': 0.6649367809295654, + 'pred_class': 'sea snake'}] + + 2. Use a config file and checkpoint to inference multiple images on GPU, + and save the visualization results in a folder. + + >>> from mmpretrain import ImageClassificationInferencer + >>> inferencer = ImageClassificationInferencer( + model='configs/resnet/resnet50_8xb32_in1k.py', + pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth', + device='cuda') + >>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/") + """ # noqa: E501 + + visualize_kwargs: set = { + 'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir', + 'wait_time' + } + + def __init__(self, + model: ModelType, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + classes=None, + **kwargs) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + if classes is not None: + self.classes = classes + else: + self.classes = getattr(self.model, '_dataset_meta', + {}).get('classes') + + def __call__(self, + inputs: InputType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + rescale_factor (float, optional): Rescale the image by the rescale + factor for visualization. This is helpful when the image is too + large or too small for visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__( + inputs, + return_datasamples=return_datasamples, + batch_size=batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + rescale_factor: Optional[float] = None, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_cls( + image, + data_sample, + classes=self.classes, + resize=resize, + show=show, + wait_time=wait_time, + rescale_factor=rescale_factor, + draw_gt=False, + draw_pred=True, + draw_score=draw_score, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + pred_scores = data_sample.pred_score + pred_score = float(torch.max(pred_scores).item()) + pred_label = torch.argmax(pred_scores).item() + result = { + 'pred_scores': pred_scores.detach().cpu().numpy(), + 'pred_label': pred_label, + 'pred_score': pred_score, + } + if self.classes is not None: + result['pred_class'] = self.classes[pred_label] + results.append(result) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Classification') diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..deae1de7975b1e46ccf045bee7fec7ddcbfbfea4 --- /dev/null +++ b/mmpretrain/apis/image_retrieval.py @@ -0,0 +1,287 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import BaseDataset, Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer, InputType, ModelType +from .model import list_models + + +class ImageRetrievalInferencer(BaseInferencer): + """The inferencer for image to image retrieval. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageRetrievalInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + prototype (str | list | dict | DataLoader, BaseDataset): The images to + be retrieved. It can be the following types: + + - str: The directory of the the images. + - list: A list of path of the images. + - dict: A config dict of the a prototype dataset. + - BaseDataset: A prototype dataset. + - DataLoader: A data loader to load the prototype data. + + prototype_cache (str, optional): The path of the generated prototype + features. If exists, directly load the cache instead of re-generate + the prototype features. If not exists, save the generated features + to the path. Defaults to None. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import ImageRetrievalInferencer + >>> inferencer = ImageRetrievalInferencer( + ... 'resnet50-arcface_inshop', + ... prototype='./demo/', + ... prototype_cache='img_retri.pth') + >>> inferencer('demo/cat-dog.png', topk=2)[0][1] + {'match_score': tensor(0.4088, device='cuda:0'), + 'sample_idx': 3, + 'sample': {'img_path': './demo/dog.jpg'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__( + self, + model: ModelType, + prototype, + prototype_cache=None, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs, + ) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.prototype_dataset = self._prepare_prototype( + prototype, prototype_cache, prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A directory path of images + prototype = dict( + type='CustomDataset', with_label=False, data_root=prototype) + + if isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, dict): + # A config of dataset + from mmpretrain.registry import DATASETS + test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + dataset = DATASETS.build(prototype) + dataloader = build_dataloader(dataset) + elif isinstance(prototype, DataLoader): + dataset = prototype.dataset + dataloader = prototype + elif isinstance(prototype, BaseDataset): + dataset = prototype + dataloader = build_dataloader(dataset) + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + if cache is not None and Path(cache).exists(): + self.model.prototype = cache + else: + self.model.prototype = dataloader + self.model.prepare_prototype() + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + self.model.dump_prototype(path) + + def __call__(self, + inputs: InputType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + topk: int = 3, + resize: Optional[int] = 224, + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_image_retrieval( + image, + data_sample, + self.prototype_dataset, + topk=topk, + resize=resize, + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + sample = self.prototype_dataset.get_data_info( + sample_idx.item()) + sample_idx = sample.pop('sample_idx') + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'sample': sample + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Retrieval') diff --git a/mmpretrain/apis/model.py b/mmpretrain/apis/model.py new file mode 100644 index 0000000000000000000000000000000000000000..eba475e7f791f42eb9aec384afec947f72722f27 --- /dev/null +++ b/mmpretrain/apis/model.py @@ -0,0 +1,408 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import fnmatch +import os.path as osp +import re +import warnings +from os import PathLike +from pathlib import Path +from typing import List, Tuple, Union + +from mmengine.config import Config +from modelindex.load_model_index import load +from modelindex.models.Model import Model + + +class ModelHub: + """A hub to host the meta information of all pre-defined models.""" + _models_dict = {} + __mmpretrain_registered = False + + @classmethod + def register_model_index(cls, + model_index_path: Union[str, PathLike], + config_prefix: Union[str, PathLike, None] = None): + """Parse the model-index file and register all models. + + Args: + model_index_path (str | PathLike): The path of the model-index + file. + config_prefix (str | PathLike | None): The prefix of all config + file paths in the model-index file. + """ + model_index = load(str(model_index_path)) + model_index.build_models_with_collections() + + for metainfo in model_index.models: + model_name = metainfo.name.lower() + if metainfo.name in cls._models_dict: + raise ValueError( + 'The model name {} is conflict in {} and {}.'.format( + model_name, osp.abspath(metainfo.filepath), + osp.abspath(cls._models_dict[model_name].filepath))) + metainfo.config = cls._expand_config_path(metainfo, config_prefix) + cls._models_dict[model_name] = metainfo + + @classmethod + def get(cls, model_name): + """Get the model's metainfo by the model name. + + Args: + model_name (str): The name of model. + + Returns: + modelindex.models.Model: The metainfo of the specified model. + """ + cls._register_mmpretrain_models() + # lazy load config + metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower())) + if metainfo is None: + raise ValueError( + f'Failed to find model "{model_name}". please use ' + '`mmpretrain.list_models` to get all available names.') + if isinstance(metainfo.config, str): + metainfo.config = Config.fromfile(metainfo.config) + return metainfo + + @staticmethod + def _expand_config_path(metainfo: Model, + config_prefix: Union[str, PathLike] = None): + if config_prefix is None: + config_prefix = osp.dirname(metainfo.filepath) + + if metainfo.config is None or osp.isabs(metainfo.config): + config_path: str = metainfo.config + else: + config_path = osp.abspath(osp.join(config_prefix, metainfo.config)) + + return config_path + + @classmethod + def _register_mmpretrain_models(cls): + # register models in mmpretrain + if not cls.__mmpretrain_registered: + from importlib_metadata import distribution + root = distribution('mmpretrain').locate_file('mmpretrain') + model_index_path = root / '.mim' / 'model-index.yml' + ModelHub.register_model_index( + model_index_path, config_prefix=root / '.mim') + cls.__mmpretrain_registered = True + + @classmethod + def has(cls, model_name): + """Whether a model name is in the ModelHub.""" + return model_name in cls._models_dict + + +def get_model(model: Union[str, Config], + pretrained: Union[str, bool] = False, + device=None, + device_map=None, + offload_folder=None, + url_mapping: Tuple[str, str] = None, + **kwargs): + """Get a pre-defined model or create a model from config. + + Args: + model (str | Config): The name of model, the config file path or a + config instance. + pretrained (bool | str): When use name to specify model, you can + use ``True`` to load the pre-defined pretrained weights. And you + can also use a string to specify the path or link of weights to + load. Defaults to False. + device (str | torch.device | None): Transfer the model to the target + device. Defaults to None. + device_map (str | dict | None): A map that specifies where each + submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every + submodule of it will be sent to the same device. You can use + `device_map="auto"` to automatically generate the device map. + Defaults to None. + offload_folder (str | None): If the `device_map` contains any value + `"disk"`, the folder where we will offload weights. + url_mapping (Tuple[str, str], optional): The mapping of pretrained + checkpoint link. For example, load checkpoint from a local dir + instead of download by ``('https://.*/', './checkpoint')``. + Defaults to None. + **kwargs: Other keyword arguments of the model config. + + Returns: + mmengine.model.BaseModel: The result model. + + Examples: + Get a ResNet-50 model and extract images feature: + + >>> import torch + >>> from mmpretrain import get_model + >>> inputs = torch.rand(16, 3, 224, 224) + >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3))) + >>> feats = model.extract_feat(inputs) + >>> for feat in feats: + ... print(feat.shape) + torch.Size([16, 256]) + torch.Size([16, 512]) + torch.Size([16, 1024]) + torch.Size([16, 2048]) + + Get Swin-Transformer model with pre-trained weights and inference: + + >>> from mmpretrain import get_model, inference_model + >>> model = get_model('swin-base_16xb64_in1k', pretrained=True) + >>> result = inference_model(model, 'demo/demo.JPEG') + >>> print(result['pred_class']) + 'sea snake' + """ # noqa: E501 + if device_map is not None: + from .utils import dispatch_model + dispatch_model._verify_require() + + metainfo = None + if isinstance(model, Config): + config = copy.deepcopy(model) + if pretrained is True and 'load_from' in config: + pretrained = config.load_from + elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py': + config = Config.fromfile(model) + if pretrained is True and 'load_from' in config: + pretrained = config.load_from + elif isinstance(model, str): + metainfo = ModelHub.get(model) + config = metainfo.config + if pretrained is True and metainfo.weights is not None: + pretrained = metainfo.weights + else: + raise TypeError('model must be a name, a path or a Config object, ' + f'but got {type(config)}') + + if pretrained is True: + warnings.warn('Unable to find pre-defined checkpoint of the model.') + pretrained = None + elif pretrained is False: + pretrained = None + + if kwargs: + config.merge_from_dict({'model': kwargs}) + config.model.setdefault('data_preprocessor', + config.get('data_preprocessor', None)) + + from mmengine.registry import DefaultScope + + from mmpretrain.registry import MODELS + with DefaultScope.overwrite_default_scope('mmpretrain'): + model = MODELS.build(config.model) + + dataset_meta = {} + if pretrained: + # Mapping the weights to GPU may cause unexpected video memory leak + # which refers to https://github.com/open-mmlab/mmdetection/pull/6405 + from mmengine.runner import load_checkpoint + if url_mapping is not None: + pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained) + checkpoint = load_checkpoint(model, pretrained, map_location='cpu') + if 'dataset_meta' in checkpoint.get('meta', {}): + # mmpretrain 1.x + dataset_meta = checkpoint['meta']['dataset_meta'] + elif 'CLASSES' in checkpoint.get('meta', {}): + # mmcls 0.x + dataset_meta = {'classes': checkpoint['meta']['CLASSES']} + + if len(dataset_meta) == 0 and 'test_dataloader' in config: + from mmpretrain.registry import DATASETS + dataset_class = DATASETS.get(config.test_dataloader.dataset.type) + dataset_meta = getattr(dataset_class, 'METAINFO', {}) + + if device_map is not None: + model = dispatch_model( + model, device_map=device_map, offload_folder=offload_folder) + elif device is not None: + model.to(device) + + model._dataset_meta = dataset_meta # save the dataset meta + model._config = config # save the config in the model + model._metainfo = metainfo # save the metainfo in the model + model.eval() + return model + + +def init_model(config, checkpoint=None, device=None, **kwargs): + """Initialize a classifier from config file (deprecated). + + It's only for compatibility, please use :func:`get_model` instead. + + Args: + config (str | :obj:`mmengine.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str | torch.device | None): Transfer the model to the target + device. Defaults to None. + **kwargs: Other keyword arguments of the model config. + + Returns: + nn.Module: The constructed model. + """ + return get_model(config, checkpoint, device, **kwargs) + + +def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]: + """List all models available in MMPretrain. + + Args: + pattern (str | None): A wildcard pattern to match model names. + Defaults to None. + exclude_patterns (list | None): A list of wildcard patterns to + exclude names from the matched names. Defaults to None. + task (str | none): The evaluation task of the model. + + Returns: + List[str]: a list of model names. + + Examples: + List all models: + + >>> from mmpretrain import list_models + >>> list_models() + + List ResNet-50 models on ImageNet-1k dataset: + + >>> from mmpretrain import list_models + >>> list_models('resnet*in1k') + ['resnet50_8xb32_in1k', + 'resnet50_8xb32-fp16_in1k', + 'resnet50_8xb256-rsb-a1-600e_in1k', + 'resnet50_8xb256-rsb-a2-300e_in1k', + 'resnet50_8xb256-rsb-a3-100e_in1k'] + + List Swin-Transformer models trained from stratch and exclude + Swin-Transformer-V2 models: + + >>> from mmpretrain import list_models + >>> list_models('swin', exclude_patterns=['swinv2', '*-pre']) + ['swin-base_16xb64_in1k', + 'swin-base_3rdparty_in1k', + 'swin-base_3rdparty_in1k-384', + 'swin-large_8xb8_cub-384px', + 'swin-small_16xb64_in1k', + 'swin-small_3rdparty_in1k', + 'swin-tiny_16xb64_in1k', + 'swin-tiny_3rdparty_in1k'] + + List all EVA models for image classification task. + + >>> from mmpretrain import list_models + >>> list_models('eva', task='Image Classification') + ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px', + 'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px', + 'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px', + 'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px', + 'eva-l-p14_mim-pre_3rdparty_in1k-196px', + 'eva-l-p14_mim-pre_3rdparty_in1k-336px'] + """ + ModelHub._register_mmpretrain_models() + matches = set(ModelHub._models_dict.keys()) + + if pattern is not None: + # Always match keys with any postfix. + matches = set(fnmatch.filter(matches, pattern + '*')) + + exclude_patterns = exclude_patterns or [] + for exclude_pattern in exclude_patterns: + exclude = set(fnmatch.filter(matches, exclude_pattern + '*')) + matches = matches - exclude + + if task is not None: + task_matches = [] + for key in matches: + metainfo = ModelHub._models_dict[key] + if metainfo.results is None and task == 'null': + task_matches.append(key) + elif metainfo.results is None: + continue + elif task in [result.task for result in metainfo.results]: + task_matches.append(key) + matches = task_matches + + return sorted(list(matches)) + + +def inference_model(model, *args, **kwargs): + """Inference an image with the inferencer. + + Automatically select inferencer to inference according to the type of + model. It's a shortcut for a quick start, and for advanced usage, please + use the correspondding inferencer class. + + Here is the mapping from task to inferencer: + + - Image Classification: :class:`ImageClassificationInferencer` + - Image Retrieval: :class:`ImageRetrievalInferencer` + - Image Caption: :class:`ImageCaptionInferencer` + - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer` + - Visual Grounding: :class:`VisualGroundingInferencer` + - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer` + - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer` + - NLVR: :class:`NLVRInferencer` + + Args: + model (BaseModel | str | Config): The loaded model, the model + name or the config of the model. + *args: Positional arguments to call the inferencer. + **kwargs: Other keyword arguments to initialize and call the + correspondding inferencer. + + Returns: + result (dict): The inference results. + """ # noqa: E501 + from mmengine.model import BaseModel + + if isinstance(model, BaseModel): + metainfo = getattr(model, '_metainfo', None) + else: + metainfo = ModelHub.get(model) + + from inspect import signature + + from .image_caption import ImageCaptionInferencer + from .image_classification import ImageClassificationInferencer + from .image_retrieval import ImageRetrievalInferencer + from .multimodal_retrieval import (ImageToTextRetrievalInferencer, + TextToImageRetrievalInferencer) + from .nlvr import NLVRInferencer + from .visual_grounding import VisualGroundingInferencer + from .visual_question_answering import VisualQuestionAnsweringInferencer + task_mapping = { + 'Image Classification': ImageClassificationInferencer, + 'Image Retrieval': ImageRetrievalInferencer, + 'Image Caption': ImageCaptionInferencer, + 'Visual Question Answering': VisualQuestionAnsweringInferencer, + 'Visual Grounding': VisualGroundingInferencer, + 'Text-To-Image Retrieval': TextToImageRetrievalInferencer, + 'Image-To-Text Retrieval': ImageToTextRetrievalInferencer, + 'NLVR': NLVRInferencer, + } + + inferencer_type = None + + if metainfo is not None and metainfo.results is not None: + tasks = set(result.task for result in metainfo.results) + inferencer_type = [ + task_mapping.get(task) for task in tasks if task in task_mapping + ] + if len(inferencer_type) > 1: + inferencer_names = [cls.__name__ for cls in inferencer_type] + warnings.warn('The model supports multiple tasks, auto select ' + f'{inferencer_names[0]}, you can also use other ' + f'inferencer {inferencer_names} directly.') + inferencer_type = inferencer_type[0] + + if inferencer_type is None: + raise NotImplementedError('No available inferencer for the model') + + init_kwargs = { + k: kwargs.pop(k) + for k in list(kwargs) + if k in signature(inferencer_type).parameters.keys() + } + + inferencer = inferencer_type(model, **init_kwargs) + return inferencer(*args, **kwargs)[0] diff --git a/mmpretrain/apis/multimodal_retrieval.py b/mmpretrain/apis/multimodal_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb9c859aca309306c1e775b7a03bf3bbc1c7717 --- /dev/null +++ b/mmpretrain/apis/multimodal_retrieval.py @@ -0,0 +1,603 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from pathlib import Path +from typing import Callable, List, Optional, Tuple, Union + +import mmengine +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import BaseDataset, Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from mmpretrain.utils import track +from .base import BaseInferencer +from .base import InputType as ImageType +from .base import ModelType +from .model import list_models + + +def filter_transforms(transforms: list, data_info: dict): + """Filter pipeline to avoid KeyError with partial data info.""" + data_info = deepcopy(data_info) + filtered_transforms = [] + for t in transforms: + try: + data_info = t(data_info) + filtered_transforms.append(t) + except KeyError: + pass + return filtered_transforms + + +class TextToImageRetrievalInferencer(BaseInferencer): + """The inferencer for text to image retrieval. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``TextToImageRetrievalInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + prototype (str | list | dict | DataLoader | BaseDataset): The images to + be retrieved. It can be the following types: + + - str: The directory of the the images. + - list: A list of path of the images. + - dict: A config dict of the a prototype dataset. + - BaseDataset: A prototype dataset. + - DataLoader: A data loader to load the prototype data. + + prototype_cache (str, optional): The path of the generated prototype + features. If exists, directly load the cache instead of re-generate + the prototype features. If not exists, save the generated features + to the path. Defaults to None. + fast_match (bool): Some algorithms will record extra image features for + further matching, which may consume large memory, set True to avoid + this behavior. Defaults to True. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import TextToImageRetrievalInferencer + >>> inferencer = TextToImageRetrievalInferencer( + ... 'blip-base_3rdparty_retrieval', + ... prototype='./demo/', + ... prototype_cache='t2i_retri.pth') + >>> inferencer('A cat and a dog.')[0] + {'match_score': tensor(0.3855, device='cuda:0'), + 'sample_idx': 1, + 'sample': {'img_path': './demo/cat-dog.png'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'show_dir', 'show', 'wait_time', 'figsize', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__(self, + model: ModelType, + prototype, + prototype_cache=None, + fast_match=True, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.img_pipeline, self.text_pipeline = self.pipeline + + if hasattr(self.model, 'fast_match'): + self.model.fast_match = fast_match + + self.prototype_dataset = self._prepare_prototype( + prototype, prototype_cache, batch_size=prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A directory path of images + prototype = dict( + type='CustomDataset', with_label=False, data_root=prototype) + + if isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, dict): + # A config of dataset + from mmpretrain.registry import DATASETS + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + prototype.setdefault('pipeline', test_pipeline) + dataset = DATASETS.build(prototype) + dataloader = build_dataloader(dataset) + elif isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, DataLoader): + dataset = prototype.dataset + dataloader = prototype + elif isinstance(prototype, BaseDataset): + dataset = prototype + dataloader = build_dataloader(dataset) + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + if cache is not None and Path(cache).exists(): + self.prototype = torch.load(cache) + else: + prototype = [] + for data_batch in track(dataloader, 'Prepare prototype...'): + with torch.no_grad(): + data_batch = self.model.data_preprocessor( + data_batch, False) + feats = self.model._run_forward(data_batch, mode='tensor') + prototype.append(feats) + prototype = { + k: torch.cat([d[k] for d in prototype]) + for k in prototype[0] + } + self.prototype = prototype + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + torch.save(self.prototype, path) + + def __call__(self, + inputs: ImageType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + @torch.no_grad() + def forward(self, data: dict, **kwargs): + """Feed the inputs to the model.""" + data = self.model.data_preprocessor(data, False) + data_samples = data['data_samples'] + feats = self.prototype.copy() + feats.update(self.model.extract_feat(data_samples=data_samples)) + return self.model.predict_all(feats, data_samples, cal_i2t=False)[0] + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + test_transfroms = [TRANSFORMS.build(t) for t in test_pipeline_cfg] + img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)} + text_info = {'text': 'example'} + img_pipeline = Compose(filter_transforms(test_transfroms, img_info)) + text_pipeline = Compose(filter_transforms(test_transfroms, text_info)) + return img_pipeline, text_pipeline + + def preprocess(self, inputs: List[str], batch_size: int = 1): + + def process_text(input_: str): + return self.text_pipeline({'text': input_}) + + chunked_data = self._get_chunk_data( + map(process_text, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[str], + preds: List[DataSample], + topk: int = 3, + figsize: Tuple[int, int] = (16, 9), + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (text, data_sample) in enumerate(zip(ori_inputs, preds)): + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_t2i_retrieval( + text, + data_sample, + self.prototype_dataset, + topk=topk, + fig_cfg=dict(figsize=figsize), + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + sample = self.prototype_dataset.get_data_info( + sample_idx.item()) + sample_idx = sample.pop('sample_idx') + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'sample': sample + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Text-To-Image Retrieval') + + +class ImageToTextRetrievalInferencer(BaseInferencer): + """The inferencer for image to text retrieval. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageToTextRetrievalInferencer.list_models()`` and you can + also query it in :doc:`/modelzoo_statistics`. + prototype (str | list | dict | DataLoader, BaseDataset): The images to + be retrieved. It can be the following types: + + - str: The file path to load the string list. + - list: A list of string. + + prototype_cache (str, optional): The path of the generated prototype + features. If exists, directly load the cache instead of re-generate + the prototype features. If not exists, save the generated features + to the path. Defaults to None. + fast_match (bool): Some algorithms will record extra image features for + further matching, which may consume large memory, set True to avoid + this behavior. Defaults to True. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import ImageToTextRetrievalInferencer + >>> inferencer = ImageToTextRetrievalInferencer( + ... 'blip-base_3rdparty_retrieval', + ... prototype=['cat', 'dog', 'snake', 'bird'], + ... prototype_cache='i2t_retri.pth') + >>> inferencer('demo/bird.JPEG')[0] + {'match_score': tensor(0.3855, device='cuda:0'), + 'sample_idx': 1, + 'sample': {'img_path': './demo/cat-dog.png'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__(self, + model: ModelType, + prototype, + prototype_cache=None, + fast_match=True, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.img_pipeline, self.text_pipeline = self.pipeline + + if hasattr(self.model, 'fast_match'): + self.model.fast_match = fast_match + + self.prototype_dataset = self._prepare_prototype( + prototype, cache=prototype_cache, batch_size=prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + [ + self.text_pipeline({ + 'sample_idx': i, + 'text': text + }) for i, text in enumerate(dataset) + ], + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A file path of a list of string + dataset = mmengine.list_from_file(prototype) + elif mmengine.utils.is_seq_of(prototype, str): + dataset = prototype + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + dataloader = build_dataloader(dataset) + + if cache is not None and Path(cache).exists(): + self.prototype = torch.load(cache) + else: + prototype = [] + for data_batch in track(dataloader, 'Prepare prototype...'): + with torch.no_grad(): + data_batch = self.model.data_preprocessor( + data_batch, False) + feats = self.model._run_forward(data_batch, mode='tensor') + prototype.append(feats) + prototype = { + k: torch.cat([d[k] for d in prototype]) + for k in prototype[0] + } + self.prototype = prototype + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + torch.save(self.prototype, path) + + def __call__(self, + inputs: ImageType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + @torch.no_grad() + def forward(self, data: dict, **kwargs): + """Feed the inputs to the model.""" + data = self.model.data_preprocessor(data, False) + feats = self.prototype.copy() + feats.update(self.model.extract_feat(images=data['images'])) + return self.model.predict_all( + feats, data['data_samples'], cal_t2i=False)[0] + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + test_transfroms = [TRANSFORMS.build(t) for t in test_pipeline_cfg] + img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)} + text_info = {'text': 'example'} + img_pipeline = Compose(filter_transforms(test_transfroms, img_info)) + text_pipeline = Compose(filter_transforms(test_transfroms, text_info)) + return img_pipeline, text_pipeline + + def preprocess(self, inputs: List[ImageType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.img_pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[ImageType], + preds: List[DataSample], + topk: int = 3, + resize: Optional[int] = 224, + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_i2t_retrieval( + image, + data_sample, + self.prototype_dataset, + topk=topk, + resize=resize, + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + text = self.prototype_dataset[sample_idx.item()] + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'text': text + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image-To-Text Retrieval') diff --git a/mmpretrain/apis/nlvr.py b/mmpretrain/apis/nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..9977c3b06f36fa61a3cd2edf36077a993b2030cd --- /dev/null +++ b/mmpretrain/apis/nlvr.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer +from .model import list_models + +InputType = Tuple[Union[str, np.ndarray], Union[str, np.ndarray], str] +InputsType = Union[List[InputType], InputType] + + +class NLVRInferencer(BaseInferencer): + """The inferencer for Natural Language for Visual Reasoning. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``NLVRInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + """ + + visualize_kwargs: set = { + 'resize', 'draw_score', 'show', 'show_dir', 'wait_time' + } + + def __call__(self, + inputs: InputsType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (tuple, List[tuple]): The input data tuples, every tuple + should include three items (left image, right image, text). + The image can be a path or numpy array. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + assert isinstance(inputs, (tuple, list)) + if isinstance(inputs, tuple): + inputs = [inputs] + for input_ in inputs: + assert isinstance(input_, tuple) + assert len(input_) == 3 + + return super().__call__( + inputs, + return_datasamples=return_datasamples, + batch_size=batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + assert test_pipeline_cfg[0]['type'] == 'ApplyToList' + + list_pipeline = deepcopy(test_pipeline_cfg[0]) + if list_pipeline.scatter_key == 'img_path': + # Remove `LoadImageFromFile` + list_pipeline.transforms.pop(0) + list_pipeline.scatter_key = 'img' + + test_pipeline = Compose( + [TRANSFORMS.build(list_pipeline)] + + [TRANSFORMS.build(t) for t in test_pipeline_cfg[1:]]) + return test_pipeline + + def preprocess(self, inputs: InputsType, batch_size: int = 1): + + def load_image(input_): + img1 = imread(input_[0]) + img2 = imread(input_[1]) + text = input_[2] + if img1 is None: + raise ValueError(f'Failed to read image {input_[0]}.') + if img2 is None: + raise ValueError(f'Failed to read image {input_[1]}.') + return dict( + img=[img1, img2], + img_shape=[img1.shape[:2], img2.shape[:2]], + ori_shape=[img1.shape[:2], img2.shape[:2]], + text=text, + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + pred_scores = data_sample.pred_score + pred_score = float(torch.max(pred_scores).item()) + pred_label = torch.argmax(pred_scores).item() + result = { + 'pred_scores': pred_scores.detach().cpu().numpy(), + 'pred_label': pred_label, + 'pred_score': pred_score, + } + results.append(result) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='NLVR') diff --git a/mmpretrain/apis/utils.py b/mmpretrain/apis/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..83e76325472f6925f78c746e3a10f3a58b0e6de4 --- /dev/null +++ b/mmpretrain/apis/utils.py @@ -0,0 +1,270 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from collections import defaultdict +from contextlib import contextmanager +from itertools import chain +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.utils import require + + +@require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/') +@require('accelerate') +def dispatch_model( + model, + device_map: Union[str, dict], + max_memory: Optional[dict] = None, + no_split_module_classes: Optional[List[str]] = None, + offload_folder: str = None, + offload_buffers: bool = False, + preload_module_classes: Optional[List[str]] = None, +): + """Split and dispatch a model across devices. + + The function depends on the `accelerate` package. Refers to + https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling + + Args: + model (torch.nn.Module): The model to dispatch. + device_map (str | dict | None): A map that specifies where each + submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every + submodule of it will be sent to the same device. You can use + `device_map="auto"` to automatically generate the device map. + Defaults to None. + max_memory (dict | None): A dictionary device identifier to maximum + memory. Will default to the maximum memory available for each GPU + and the available CPU RAM if unset. Defaults to None. + no_split_module_classes (List[str] | None): A list of layer class names + that should never be split across device (for instance any layer + that has a residual connection). If None, try to get the settings + from the model class. Defaults to None. + offload_folder (str | None): If the `device_map` contains any value + `"disk"`, the folder where we will offload weights. + offload_buffers (bool): In the layers that are offloaded on the CPU + or the hard drive, whether or not to offload the buffers as + well as the parameters. Defaults to False. + preload_module_classes (List[str] | None): A list of classes whose + instances should load all their weights (even in the submodules) at + the beginning of the forward. This should only be used for classes + that have submodules which are registered but not called directly + during the forward, for instance if a `dense` linear layer is + registered, but at forward, `dense.weight` and `dense.bias` are + used in some operations instead of calling `dense` directly. + Defaults to None. + """ + from accelerate import dispatch_model, infer_auto_device_map + + # Check valid device_map string. + valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential'] + if isinstance(device_map, str) and device_map not in valid_map_option: + raise ValueError('If passing a string for `device_map`, please choose ' + f'from {valid_map_option}.') + + # Generate device map automatically + if isinstance(device_map, str): + if no_split_module_classes is None: + no_split_module_classes = getattr(model, '_no_split_modules', None) + if no_split_module_classes is None: + raise ValueError(f'{model.__class__.__name__} does not support ' + f"`device_map='{device_map}'` yet.") + + if device_map != 'sequential': + from accelerate.utils import get_balanced_memory + max_memory = get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=None, + low_zero=(device_map == 'balanced_low_0'), + ) + max_memory[0] *= 0.9 + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=None, + ) + + if 'disk' in device_map.values(): + if offload_folder is None: + raise ValueError( + 'The current `device_map` had weights offloaded to the disk. ' + 'Please provide an `offload_folder` for them.') + os.makedirs(offload_folder, exist_ok=True) + + main_device = next( + (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu') + + model = dispatch_model( + model, + device_map=device_map, + main_device=main_device, + offload_dir=offload_folder, + offload_buffers=offload_buffers, + preload_module_classes=preload_module_classes, + ) + if hasattr(model, 'data_preprocessor'): + model.data_preprocessor._device = torch.device(main_device) + return model + + +@contextmanager +def init_empty_weights(include_buffers: bool = False): + """A context manager under which models are initialized with all parameters + on the meta device. + + With this context manager, we can create an empty model. Useful when just + initializing the model would blow the available RAM. + + Besides move the parameters to meta device, this method will also avoid + load checkpoint from `mmengine.runner.load_checkpoint` and + `transformers.PreTrainedModel.from_pretrained`. + + Modified from https://github.com/huggingface/accelerate + + Args: + include_buffers (bool): Whether put all buffers on the meta device + during initialization. + """ + device = torch.device('meta') + + # move parameter and buffer to meta device + old_register_parameter = nn.Module.register_parameter + if include_buffers: + old_register_buffer = nn.Module.register_buffer + # See https://github.com/huggingface/accelerate/pull/699 + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ['empty', 'zeros', 'ones', 'full'] + } + + def register_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs) + + def register_buffer(module, name, buffer, *args, **kwargs): + old_register_buffer(module, name, buffer, *args, **kwargs) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + def patch_tensor_constructor(fn): + + def wrapper(*args, **kwargs): + kwargs['device'] = device + return fn(*args, **kwargs) + + return wrapper + + # Patch load_checkpoint + import mmengine.runner.checkpoint as mmengine_load + old_load_checkpoint = mmengine_load.load_checkpoint + + def patch_load_checkpoint(*args, **kwargs): + return {} + + # Patch transformers from pretrained + try: + from transformers import PreTrainedModel + from transformers.models.auto.auto_factory import (AutoConfig, + _BaseAutoModelClass) + with_transformers = True + except ImportError: + with_transformers = False + + @classmethod + def patch_auto_model(cls, pretrained_model_name_or_path, *model_args, + **kwargs): + cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, + *model_args, **kwargs) + return cls.from_config(cfg) + + @classmethod + def patch_pretrained_model(cls, pretrained_model_name_or_path, *model_args, + **kwargs): + cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path, + *model_args, **kwargs) + return cls(cfg) + + if with_transformers: + old_pretrained_model = PreTrainedModel.from_pretrained + old_auto_model = _BaseAutoModelClass.from_pretrained + + try: + nn.Module.register_parameter = register_parameter + mmengine_load.load_checkpoint = patch_load_checkpoint + if with_transformers: + PreTrainedModel.from_pretrained = patch_pretrained_model + _BaseAutoModelClass.from_pretrained = patch_auto_model + if include_buffers: + nn.Module.register_buffer = register_buffer + for func in tensor_constructors_to_patch.keys(): + tensor_constructor = patch_tensor_constructor( + getattr(torch, func)) + setattr(torch, func, tensor_constructor) + yield + finally: + nn.Module.register_parameter = old_register_parameter + mmengine_load.load_checkpoint = old_load_checkpoint + if with_transformers: + PreTrainedModel.from_pretrained = old_pretrained_model + _BaseAutoModelClass.from_pretrained = old_auto_model + if include_buffers: + nn.Module.register_buffer = old_register_buffer + for func, ori in tensor_constructors_to_patch.items(): + setattr(torch, func, ori) + + +def compute_module_sizes( + model: nn.Module, + dtype: Union[str, torch.dtype, None] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None): + """Compute the size of each submodule of a given model.""" + + def get_dtype(dtype): + if isinstance(dtype, str): + dtype = getattr(torch, dtype) + if dtype is not None: + assert issubclass(dtype, torch.dtype) + return dtype + + def dtype_bytes(dtype: torch.dtype): + if dtype is torch.bool: + return 1 + if dtype.is_floating_point: + return torch.finfo(dtype).bits / 8 + else: + return torch.iinfo(dtype).bits / 8 + + if dtype is not None: + dtype = get_dtype(dtype) + dtype_size = dtype_bytes(dtype) + + if special_dtypes is not None: + special_dtypes = { + key: dtype_bytes(dtype) + for key, dtype in special_dtypes.items() + } + + module_sizes = defaultdict(int) + for name, tensor in chain( + model.named_parameters(recurse=True), + model.named_buffers(recurse=True)): + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes[name] + elif dtype is None: + size = tensor.numel() * tensor.element_size() + else: + size = tensor.numel() * min(dtype_size, tensor.element_size()) + name_parts = name.split('.') + for idx in range(len(name_parts) + 1): + module_sizes['.'.join(name_parts[:idx])] += size + + return module_sizes diff --git a/mmpretrain/apis/visual_grounding.py b/mmpretrain/apis/visual_grounding.py new file mode 100644 index 0000000000000000000000000000000000000000..0153d56f5ca10a32e9fd2ccabb0d15c1135e213d --- /dev/null +++ b/mmpretrain/apis/visual_grounding.py @@ -0,0 +1,182 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer +from .model import list_models + + +class VisualGroundingInferencer(BaseInferencer): + """The inferencer for visual grounding. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``VisualGroundingInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import VisualGroundingInferencer + >>> inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco') + >>> inferencer('demo/cat-dog.png', 'dog')[0] + {'pred_bboxes': tensor([[ 36.6000, 29.6000, 355.8000, 395.2000]])} + """ # noqa: E501 + + visualize_kwargs: set = { + 'resize', 'show', 'show_dir', 'wait_time', 'line_width', 'bbox_color' + } + + def __call__(self, + images: Union[str, np.ndarray, list], + texts: Union[str, list], + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + images (str | array | list): The image path or array, or a list of + images. + texts (str | list): The text to do visual grounding. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + line_width (int): The line width of the bbox. Defaults to 3. + bbox_color (str | tuple): The color of the bbox. + Defaults to 'green'. + + Returns: + list: The inference results. + """ + if not isinstance(images, (list, tuple)): + assert isinstance(texts, str) + inputs = [{'img': images, 'text': texts}] + else: + inputs = [] + for i in range(len(images)): + input_ = {'img': images[i], 'text': texts[i]} + inputs.append(input_) + + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[dict], batch_size: int = 1): + + def load_image(input_: dict): + img = imread(input_['img']) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return {**input_, 'img': img} + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[dict], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + line_width: int = 3, + bbox_color: Union[str, tuple] = 'green', + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_['img']) + if isinstance(input_['img'], str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_['img']).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_visual_grounding( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + line_width=line_width, + bbox_color=bbox_color, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({'pred_bboxes': data_sample.get('pred_bboxes')}) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Visual Grounding') diff --git a/mmpretrain/apis/visual_question_answering.py b/mmpretrain/apis/visual_question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..616e1edc66709401df83cb5253590325e727aa98 --- /dev/null +++ b/mmpretrain/apis/visual_question_answering.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer +from .model import list_models + + +class VisualQuestionAnsweringInferencer(BaseInferencer): + """The inferencer for visual question answering. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``VisualQuestionAnsweringInferencer.list_models()`` and you can + also query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import VisualQuestionAnsweringInferencer + >>> inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa') + >>> inferencer('demo/cat-dog.png', "What's the animal next to the dog?")[0] + {'question': "What's the animal next to the dog?", 'pred_answer': 'cat'} + """ # noqa: E501 + + visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'} + + def __call__(self, + images: Union[str, np.ndarray, list], + questions: Union[str, list], + return_datasamples: bool = False, + batch_size: int = 1, + objects: Optional[List[str]] = None, + **kwargs) -> dict: + """Call the inferencer. + + Args: + images (str | array | list): The image path or array, or a list of + images. + questions (str | list): The question to the correspondding image. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + objects (List[List[str]], optional): Some algorithms like OFA + fine-tuned VQA models requires extra object description list + for every image. Defaults to None. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + if not isinstance(images, (list, tuple)): + assert isinstance(questions, str) + inputs = [{'img': images, 'question': questions}] + if objects is not None: + assert isinstance(objects[0], str) + inputs[0]['objects'] = objects + else: + inputs = [] + for i in range(len(images)): + input_ = {'img': images[i], 'question': questions[i]} + if objects is not None: + input_['objects'] = objects[i] + inputs.append(input_) + + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[dict], batch_size: int = 1): + + def load_image(input_: dict): + img = imread(input_['img']) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return {**input_, 'img': img} + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[dict], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_['img']) + if isinstance(input_['img'], str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_['img']).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_vqa( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({ + 'question': data_sample.get('question'), + 'pred_answer': data_sample.get('pred_answer'), + }) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Visual Question Answering') diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32.py new file mode 100644 index 0000000000000000000000000000000000000000..7d074008cc204f4ac486dc04fb3f1c638fb9e161 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py new file mode 100644 index 0000000000000000000000000000000000000000..b687a06fef86827a76a472371f4de7dc2a5f2ac1 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmcv.transforms import (LoadImageFromFile, RandomApply, RandomFlip, + RandomGrayscale) +from mmengine.dataset import DefaultSampler, default_collate + +from mmpretrain.datasets import (ColorJitter, GaussianBlur, ImageNet, + MultiView, PackInputs, RandomResizedCrop) +from mmpretrain.models import SelfSupDataPreprocessor + +# dataset settings +dataset_type = 'ImageNet' +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=SelfSupDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True) + +view_pipeline = [ + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomApply, + transforms=[ + dict( + type=ColorJitter, + brightness=0.8, + contrast=0.8, + saturation=0.8, + hue=0.2) + ], + prob=0.8), + dict( + type=RandomGrayscale, + prob=0.2, + keep_channels=True, + channel_weights=(0.114, 0.587, 0.2989)), + dict( + type=GaussianBlur, + magnitude_range=(0.1, 2.0), + magnitude_std='inf', + prob=0.5), +] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=MultiView, num_views=2, transforms=[view_pipeline]), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=32, + num_workers=4, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate), + dataset=dict( + type=ImageNet, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py b/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py new file mode 100644 index 0000000000000000000000000000000000000000..f64d5eac39902bdf8c95a0e7f7d589d069951df5 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmcv.transforms import LoadImageFromFile, RandomFlip +from mmengine.dataset.sampler import DefaultSampler + +from mmpretrain.datasets import ImageNet, PackInputs, RandomResizedCrop +from mmpretrain.models import SelfSupDataPreprocessor + +# dataset settings +dataset_type = 'ImageNet' +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=SelfSupDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + crop_ratio_range=(0.2, 1.0), + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=512, + num_workers=8, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type=ImageNet, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py new file mode 100644 index 0000000000000000000000000000000000000000..85aeb1e2c131109f3f6d75d21e2cc1c782c82b7f --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (ImageNet, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, Resize) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=384, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=384, backend='pillow', interpolation='bicubic'), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/default_runtime.py b/mmpretrain/configs/_base_/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..b5c748eb84b3e50d7c6b30efaa87cd3c1f2f1827 --- /dev/null +++ b/mmpretrain/configs/_base_/default_runtime.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.visualization import LocalVisBackend + +from mmpretrain.engine.hooks import VisualizationHook +from mmpretrain.visualization import UniversalVisualizer + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + + # print log every 100 iterations. + logger=dict(type=LoggerHook, interval=100), + + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + + # save checkpoint per epoch. + checkpoint=dict(type=CheckpointHook, interval=1), + + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), + + # validation results visualization, set True to enable it. + visualization=dict(type=VisualizationHook, enable=False), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +vis_backends = [dict(type=LocalVisBackend)] +visualizer = dict(type=UniversalVisualizer, vis_backends=vis_backends) + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# Do not need to specify default_scope with new config. Therefore set it to +# None to avoid BC-breaking. +default_scope = None diff --git a/mmpretrain/configs/_base_/models/convnext_base.py b/mmpretrain/configs/_base_/models/convnext_base.py new file mode 100644 index 0000000000000000000000000000000000000000..6315b2f1966d2484739087e1e131fe8dd9a2ad56 --- /dev/null +++ b/mmpretrain/configs/_base_/models/convnext_base.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.model import TruncNormalInit + +from mmpretrain.models import (ConvNeXt, CutMix, ImageClassifier, + LabelSmoothLoss, LinearClsHead, Mixup) + +# Model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=ConvNeXt, arch='base', drop_path_rate=0.5), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type=TruncNormalInit, layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0), + ]), +) diff --git a/mmpretrain/configs/_base_/models/mae_vit_base_p16.py b/mmpretrain/configs/_base_/models/mae_vit_base_p16.py new file mode 100644 index 0000000000000000000000000000000000000000..9347d1e8810e553ef5563a96198794ec139ea3a4 --- /dev/null +++ b/mmpretrain/configs/_base_/models/mae_vit_base_p16.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (MAE, MAEPretrainDecoder, MAEPretrainHead, + MAEViT, PixelReconstructionLoss) + +# model settings +model = dict( + type=MAE, + backbone=dict(type=MAEViT, arch='b', patch_size=16, mask_ratio=0.75), + neck=dict( + type=MAEPretrainDecoder, + patch_size=16, + in_chans=3, + embed_dim=768, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4., + ), + head=dict( + type=MAEPretrainHead, + norm_pix=True, + patch_size=16, + loss=dict(type=PixelReconstructionLoss, criterion='L2')), + init_cfg=[ + dict(type='Xavier', layer='Linear', distribution='uniform'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) diff --git a/mmpretrain/configs/_base_/models/resnet18.py b/mmpretrain/configs/_base_/models/resnet18.py new file mode 100644 index 0000000000000000000000000000000000000000..30b8f65148611c5602858b875b9be89b31f225cb --- /dev/null +++ b/mmpretrain/configs/_base_/models/resnet18.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, ResNet) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=ResNet, + depth=18, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=512, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5), + )) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py b/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..60ccaa0e25ec69aa618430f51a60d949506fc406 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +# for batch in each gpu is 128, 8 gpu +# lr = 5e-4 * 128 * 8 / 512 = 0.001 +optim_wrapper = dict( + optimizer=dict( + type=AdamW, + lr=5e-4 * 1024 / 512, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + flat_decay_mult=0.0, + custom_keys={ + '.absolute_pos_embed': dict(decay_mult=0.0), + '.relative_position_bias_table': dict(decay_mult=0.0) + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type=LinearLR, + start_factor=1e-3, + by_epoch=True, + end=20, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict(type=CosineAnnealingLR, eta_min=1e-5, by_epoch=True, begin=20) +] + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs256.py b/mmpretrain/configs/_base_/schedules/imagenet_bs256.py new file mode 100644 index 0000000000000000000000000000000000000000..95afa2ad292c277a84aa274786ee34a9d6b8b0ef --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs256.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import MultiStepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + type=MultiStepLR, by_epoch=True, milestones=[30, 60, 90], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=256) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py b/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7e6171e2aeb20a94277e7ca4d02b2598d73b8e --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop + +from mmpretrain.engine.optimizers.lars import LARS + +# optimizer wrapper +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=LARS, lr=4.8, weight_decay=1e-6, momentum=0.9)) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict(type=CosineAnnealingLR, T_max=190, by_epoch=True, begin=10, end=200) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=200) diff --git a/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py b/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9c329abba18ecb5bad1875090f7de667a77391 --- /dev/null +++ b/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmengine.dataset import DefaultSampler, default_collate +from mmengine.hooks import CheckpointHook +from mmengine.model import ConstantInit, PretrainedInit, TruncNormalInit +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.datasets import (BEiTMaskGenerator, ColorJitter, ImageNet, + LoadImageFromFile, PackInputs, RandomFlip, + RandomResizedCropAndInterpolationWithTwoPic) +from mmpretrain.models import (BEiT, BEiTPretrainViT, BEiTV1Head, + CrossEntropyLoss, DALLEEncoder, + TwoNormDataPreprocessor) + +# dataset settings +dataset_type = ImageNet +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=TwoNormDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + second_mean=[-31.875, -31.875, -31.875], + second_std=[318.75, 318.75, 318.75], + to_rgb=True) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, + hue=0.), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandomResizedCropAndInterpolationWithTwoPic, + size=224, + second_size=112, + interpolation='bicubic', + second_interpolation='lanczos', + scale=(0.08, 1.0)), + dict( + type=BEiTMaskGenerator, + input_size=(14, 14), + num_masking_patches=75, + max_num_patches=None, + min_num_patches=16), + dict(type=PackInputs) +] +train_dataloader = dict( + batch_size=256, + num_workers=8, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) + +# model settings +model = dict( + type=BEiT, + backbone=dict( + type=BEiTPretrainViT, + arch='base', + patch_size=16, + drop_path_rate=0.1, + final_norm=True, + out_type='raw', + layer_scale_init_value=0.1, + init_cfg=[ + dict(type=TruncNormalInit, std=0.02, layer='Linear'), + dict(type=TruncNormalInit, std=0.02, layer='Conv2d'), + dict(type=ConstantInit, layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=None, + head=dict( + type=BEiTV1Head, + embed_dims=768, + num_embed=8192, + loss=dict(type=CrossEntropyLoss)), + target_generator=dict( + type=DALLEEncoder, + init_cfg=dict( + type=PretrainedInit, + checkpoint= # noqa: E251 + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/dalle_encoder.pth', # noqa: E501 + ))) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05), + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=1e-5, + by_epoch=True, + begin=10, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +default_hooks.update( + # only keeps the latest 3 checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) + +randomness.update(seed=0, diff_rank_seed=True) + +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=2048) diff --git a/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py b/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..46be2ca6dcb0636469ca19dda51affd003a8b812 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +from mmpretrain.engine import EMAHook + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py b/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..a254ac8a84d94acdd1ec5f84059c7e75abc3cbc4 --- /dev/null +++ b/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks import CheckpointHook +from mmengine.optim import CosineAnnealingLR, LinearLR, OptimWrapper +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.models import (EVA, CLIPGenerator, CosineSimilarityLoss, + MAEPretrainDecoder, MIMHead) + +# dataset settings +train_dataloader.batch_size = 256 + +# model settings +model.type = EVA +model.init_cfg = None +model.backbone.update(init_cfg=[ + dict(type='Xavier', distribution='uniform', layer='Linear'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) +]) +model.neck.update( + type=MAEPretrainDecoder, + predict_feature_dim=512, + init_cfg=[ + dict(type='Xavier', distribution='uniform', layer='Linear'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) +model.head = dict( + type=MIMHead, + loss=dict(type=CosineSimilarityLoss, shift_factor=2.0, scale_factor=2.0)) +model.target_generator = dict( + type=CLIPGenerator, + tokenizer_path= # noqa + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/clip_vit_base_16.pth.tar' # noqa +) + +# optimizer wrapper +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) +find_unused_parameters = True + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(dict(seed=0, diff_rank_seed=True)) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..6cee3bc93fd8b1c65263ac415422b9b73628e88d --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=260, + by_epoch=True, + begin=40, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py b/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..f16d248b6988c924e8540a7782dabee4997baba1 --- /dev/null +++ b/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32 import * + from .._base_.default_runtime import * + from .._base_.models.resnet18 import * + from .._base_.schedules.imagenet_bs256 import * diff --git a/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py b/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..09c738f219e561e5863dd2ce8246af005502bc83 --- /dev/null +++ b/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32_simclr import * + from .._base_.schedules.imagenet_lars_coslr_200e import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper + +from mmpretrain.engine.optimizers.lars import LARS +from mmpretrain.models.backbones.resnet import ResNet +from mmpretrain.models.heads.contrastive_head import ContrastiveHead +from mmpretrain.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmpretrain.models.necks.nonlinear_neck import NonLinearNeck +from mmpretrain.models.selfsup.simclr import SimCLR + +# dataset settings +train_dataloader.merge(dict(batch_size=256)) + +# model settings +model = dict( + type=SimCLR, + backbone=dict( + type=ResNet, + depth=50, + norm_cfg=dict(type='SyncBN'), + zero_init_residual=True), + neck=dict( + type=NonLinearNeck, # SimCLR non-linear neck + in_channels=2048, + hid_channels=2048, + out_channels=128, + num_layers=2, + with_avg_pool=True), + head=dict( + type=ContrastiveHead, + loss=dict(type=CrossEntropyLoss), + temperature=0.1), +) + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=LARS, lr=4.8, momentum=0.9, weight_decay=1e-6), + paramwise_cfg=dict( + custom_keys={ + 'bn': dict(decay_mult=0, lars_exclude=True), + 'bias': dict(decay_mult=0, lars_exclude=True), + # bn layer in ResNet block downsample module + 'downsample.1': dict(decay_mult=0, lars_exclude=True) + })) + +# runtime settings +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=10, max_keep_ckpts=3) diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b6be47dceb288acc2517a7ba2c890f2a38b671 --- /dev/null +++ b/mmpretrain/datasets/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpretrain.utils.dependency import WITH_MULTIMODAL +from .base_dataset import BaseDataset +from .builder import build_dataset +from .caltech101 import Caltech101 +from .cifar import CIFAR10, CIFAR100 +from .cub import CUB +from .custom import CustomDataset +from .dataset_wrappers import KFoldDataset +from .dtd import DTD +from .fgvcaircraft import FGVCAircraft +from .flowers102 import Flowers102 +from .food101 import Food101 +from .imagenet import ImageNet, ImageNet21k +from .inshop import InShop +from .mnist import MNIST, FashionMNIST +from .multi_label import MultiLabelDataset +from .multi_task import MultiTaskDataset +from .nlvr2 import NLVR2 +from .oxfordiiitpet import OxfordIIITPet +from .places205 import Places205 +from .samplers import * # noqa: F401,F403 +from .stanfordcars import StanfordCars +from .sun397 import SUN397 +from .transforms import * # noqa: F401,F403 +from .voc import VOC + +__all__ = [ + 'BaseDataset', 'CIFAR10', 'CIFAR100', 'CUB', 'Caltech101', 'CustomDataset', + 'DTD', 'FGVCAircraft', 'FashionMNIST', 'Flowers102', 'Food101', 'ImageNet', + 'ImageNet21k', 'InShop', 'KFoldDataset', 'MNIST', 'MultiLabelDataset', + 'MultiTaskDataset', 'NLVR2', 'OxfordIIITPet', 'Places205', 'SUN397', + 'StanfordCars', 'VOC', 'build_dataset' +] + +if WITH_MULTIMODAL: + from .coco_caption import COCOCaption + from .coco_retrieval import COCORetrieval + from .coco_vqa import COCOVQA + from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA + from .flickr30k_caption import Flickr30kCaption + from .flickr30k_retrieval import Flickr30kRetrieval + from .gqa_dataset import GQA + from .nocaps import NoCaps + from .ocr_vqa import OCRVQA + from .refcoco import RefCOCO + from .scienceqa import ScienceQA + from .textvqa import TextVQA + from .visual_genome import VisualGenomeQA + from .vizwiz import VizWiz + from .vsr import VSR + + __all__.extend([ + 'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption', + 'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval', + 'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', + 'VSR', 'VizWiz', 'OCRVQA' + ]) diff --git a/mmpretrain/datasets/__pycache__/__init__.cpython-38.pyc b/mmpretrain/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86138c4e807baf0d5df8517dc7daddfd17743079 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/base_dataset.cpython-38.pyc b/mmpretrain/datasets/__pycache__/base_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..987fc2bbc7c5277a9d87cd9ad6cfc40f1a68b336 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/base_dataset.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/builder.cpython-38.pyc b/mmpretrain/datasets/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49c88bf76c4080738ef2ef9983116de35e01afba Binary files /dev/null and b/mmpretrain/datasets/__pycache__/builder.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/caltech101.cpython-38.pyc b/mmpretrain/datasets/__pycache__/caltech101.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13b26c064f946f5a9edcbc7f98c4deb11209923e Binary files /dev/null and b/mmpretrain/datasets/__pycache__/caltech101.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/categories.cpython-38.pyc b/mmpretrain/datasets/__pycache__/categories.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7637efbe554357d50e85c420574d63e9f18db8f Binary files /dev/null and b/mmpretrain/datasets/__pycache__/categories.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/cifar.cpython-38.pyc b/mmpretrain/datasets/__pycache__/cifar.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdff96ed0d0cdcc97fefc21d7e6ea00a152aba2e Binary files /dev/null and b/mmpretrain/datasets/__pycache__/cifar.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/coco_caption.cpython-38.pyc b/mmpretrain/datasets/__pycache__/coco_caption.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b5c4f877657db08f61ad957666cc88250418e21 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/coco_caption.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/coco_retrieval.cpython-38.pyc b/mmpretrain/datasets/__pycache__/coco_retrieval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56d7423d9fd3fa4bda43508c0ee830c5cdc89b2c Binary files /dev/null and b/mmpretrain/datasets/__pycache__/coco_retrieval.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/coco_vqa.cpython-38.pyc b/mmpretrain/datasets/__pycache__/coco_vqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd5f207c61d7dac58cf3c47d404355b282e61d0d Binary files /dev/null and b/mmpretrain/datasets/__pycache__/coco_vqa.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/cub.cpython-38.pyc b/mmpretrain/datasets/__pycache__/cub.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f896f234938963213e7797506039e46460b4b23 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/cub.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/custom.cpython-38.pyc b/mmpretrain/datasets/__pycache__/custom.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fef35c6b51eb9d211fefd68671d6bc3de1599a7 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/custom.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/dataset_wrappers.cpython-38.pyc b/mmpretrain/datasets/__pycache__/dataset_wrappers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d39db136982d2b0b589fc4913a170d5f514f0249 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/dataset_wrappers.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/dtd.cpython-38.pyc b/mmpretrain/datasets/__pycache__/dtd.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dee44a6f64bbed91555b2c3fb83a1d867bf5aa60 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/dtd.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/fgvcaircraft.cpython-38.pyc b/mmpretrain/datasets/__pycache__/fgvcaircraft.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f80dd69f1597975f40557a904475510ec7df398f Binary files /dev/null and b/mmpretrain/datasets/__pycache__/fgvcaircraft.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/flamingo.cpython-38.pyc b/mmpretrain/datasets/__pycache__/flamingo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f05cbfafe94af3dfa4a96d6b40e3c045e1f5d851 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/flamingo.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/flickr30k_caption.cpython-38.pyc b/mmpretrain/datasets/__pycache__/flickr30k_caption.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5957d30709875a8da03bd04be3b84845941fdbe Binary files /dev/null and b/mmpretrain/datasets/__pycache__/flickr30k_caption.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/flickr30k_retrieval.cpython-38.pyc b/mmpretrain/datasets/__pycache__/flickr30k_retrieval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d35880f3e9ed70debe4a0e2b226fe64ca739a184 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/flickr30k_retrieval.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/flowers102.cpython-38.pyc b/mmpretrain/datasets/__pycache__/flowers102.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44ac082ff98e4f43902dd42c1b2cd491ead1103a Binary files /dev/null and b/mmpretrain/datasets/__pycache__/flowers102.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/food101.cpython-38.pyc b/mmpretrain/datasets/__pycache__/food101.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0db483d2f9daaaf127fc3ac5bb9c1233cd41869f Binary files /dev/null and b/mmpretrain/datasets/__pycache__/food101.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/gqa_dataset.cpython-38.pyc b/mmpretrain/datasets/__pycache__/gqa_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde140af94e82c03c8a7e93b347dcc30c0a04010 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/gqa_dataset.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/imagenet.cpython-38.pyc b/mmpretrain/datasets/__pycache__/imagenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aa74c0f2a722575fd09fbd77b38d7503b94ccdc Binary files /dev/null and b/mmpretrain/datasets/__pycache__/imagenet.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/inshop.cpython-38.pyc b/mmpretrain/datasets/__pycache__/inshop.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2050be3585e316e10dc8204b3ceec3ab016c0dc Binary files /dev/null and b/mmpretrain/datasets/__pycache__/inshop.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/mnist.cpython-38.pyc b/mmpretrain/datasets/__pycache__/mnist.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..755bc57048b50ce4424689b7fe1e12ef82b75777 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/mnist.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/multi_label.cpython-38.pyc b/mmpretrain/datasets/__pycache__/multi_label.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a986a8faa7d92b4cfe6319ab56c25d739274ec8c Binary files /dev/null and b/mmpretrain/datasets/__pycache__/multi_label.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/multi_task.cpython-38.pyc b/mmpretrain/datasets/__pycache__/multi_task.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97af6d4ca34c7850590483aafd73bdce8879a64b Binary files /dev/null and b/mmpretrain/datasets/__pycache__/multi_task.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/nlvr2.cpython-38.pyc b/mmpretrain/datasets/__pycache__/nlvr2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667bbe24e5ca3b43e4a6f760eaa898cf4a0114ae Binary files /dev/null and b/mmpretrain/datasets/__pycache__/nlvr2.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/nocaps.cpython-38.pyc b/mmpretrain/datasets/__pycache__/nocaps.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d87383515aa74ab9076ce5780b82e8340a55c40 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/nocaps.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/ocr_vqa.cpython-38.pyc b/mmpretrain/datasets/__pycache__/ocr_vqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84dc2861118e17089eeefec33001d5ccf2cad2e8 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/ocr_vqa.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/oxfordiiitpet.cpython-38.pyc b/mmpretrain/datasets/__pycache__/oxfordiiitpet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6da9259d9336703e576999ca00473aa03897f986 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/oxfordiiitpet.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/places205.cpython-38.pyc b/mmpretrain/datasets/__pycache__/places205.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ed784a999d2ad479f2faf17dbeb9a834bb041ab Binary files /dev/null and b/mmpretrain/datasets/__pycache__/places205.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/refcoco.cpython-38.pyc b/mmpretrain/datasets/__pycache__/refcoco.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4e4425caa2e0485304f78406601159b32a0cd3d Binary files /dev/null and b/mmpretrain/datasets/__pycache__/refcoco.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/scienceqa.cpython-38.pyc b/mmpretrain/datasets/__pycache__/scienceqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3af13ffd3dcce838a11a86889345e3cd61b74939 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/scienceqa.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/stanfordcars.cpython-38.pyc b/mmpretrain/datasets/__pycache__/stanfordcars.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d58b147212c13fc66e8de94cf160de41c39fefdc Binary files /dev/null and b/mmpretrain/datasets/__pycache__/stanfordcars.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/sun397.cpython-38.pyc b/mmpretrain/datasets/__pycache__/sun397.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efbc882a29c11024373d72f520f1c151ddcf9cb3 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/sun397.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/textvqa.cpython-38.pyc b/mmpretrain/datasets/__pycache__/textvqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46679ddf891e11c4054ddeb4ed83b16a20392729 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/textvqa.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/utils.cpython-38.pyc b/mmpretrain/datasets/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a2ce2a1cbae66c671f9e1f43dbdedd64b40e84d Binary files /dev/null and b/mmpretrain/datasets/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/visual_genome.cpython-38.pyc b/mmpretrain/datasets/__pycache__/visual_genome.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3323fa33ff2add992f07c7b00bd0744b9f12032 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/visual_genome.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/vizwiz.cpython-38.pyc b/mmpretrain/datasets/__pycache__/vizwiz.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f8dd8cd38ca08d736e3dd293f78c22caf63a882 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/vizwiz.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/voc.cpython-38.pyc b/mmpretrain/datasets/__pycache__/voc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70ab39b8464c043699f1936c5edb96cbdda946c0 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/voc.cpython-38.pyc differ diff --git a/mmpretrain/datasets/__pycache__/vsr.cpython-38.pyc b/mmpretrain/datasets/__pycache__/vsr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a4eb2afdd88ae66f6bd59f35d79701d49790735 Binary files /dev/null and b/mmpretrain/datasets/__pycache__/vsr.cpython-38.pyc differ diff --git a/mmpretrain/datasets/base_dataset.py b/mmpretrain/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dffdf04772163b5fa55afabc8e15ac8c118aadd2 --- /dev/null +++ b/mmpretrain/datasets/base_dataset.py @@ -0,0 +1,219 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from os import PathLike +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +from mmengine.dataset import BaseDataset as _BaseDataset + +from mmpretrain.registry import DATASETS, TRANSFORMS + + +def expanduser(path): + """Expand ~ and ~user constructions. + + If user or $HOME is unknown, do nothing. + """ + if isinstance(path, (str, PathLike)): + return osp.expanduser(path) + else: + return path + + +@DATASETS.register_module() +class BaseDataset(_BaseDataset): + """Base dataset for image classification task. + + This dataset support annotation file in `OpenMMLab 2.0 style annotation + format`. + + .. _OpenMMLab 2.0 style annotation format: + https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md + + Comparing with the :class:`mmengine.BaseDataset`, this class implemented + several useful methods. + + Args: + ann_file (str): Annotation file path. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None, which means using all ``data_infos``. + serialize_data (bool): Whether to hold memory using serialized objects, + when enabled, data loader workers can use shared RAM from master + process instead of making a copy. Defaults to True. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + test_mode (bool, optional): ``test_mode=True`` means in test phase, + an error will be raised when getting an item fails, ``test_mode=False`` + means in training phase, another item will be returned randomly. + Defaults to False. + lazy_init (bool): Whether to load annotation during instantiation. + In some cases, such as visualization, only the meta information of + the dataset is needed, which is not necessary to load annotation + file. ``Basedataset`` can skip load annotations to save time by set + ``lazy_init=False``. Defaults to False. + max_refetch (int): If ``Basedataset.prepare_data`` get a None img. + The maximum extra number of cycles to get a valid image. + Defaults to 1000. + classes (str | Sequence[str], optional): Specify names of classes. + + - If is string, it should be a file path, and the every line of + the file is a name of a class. + - If is a sequence of string, every item is a name of class. + - If is None, use categories information in ``metainfo`` argument, + annotation file or the class attribute ``METAINFO``. + + Defaults to None. + """ # noqa: E501 + + def __init__(self, + ann_file: str, + metainfo: Optional[dict] = None, + data_root: str = '', + data_prefix: Union[str, dict] = '', + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: Sequence = (), + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + classes: Union[str, Sequence[str], None] = None): + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + + ann_file = expanduser(ann_file) + metainfo = self._compat_classes(metainfo, classes) + + transforms = [] + for transform in pipeline: + if isinstance(transform, dict): + transforms.append(TRANSFORMS.build(transform)) + else: + transforms.append(transform) + + super().__init__( + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=transforms, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + @property + def img_prefix(self): + """The prefix of images.""" + return self.data_prefix['img_path'] + + @property + def CLASSES(self): + """Return all categories names.""" + return self._metainfo.get('classes', None) + + @property + def class_to_idx(self): + """Map mapping class name to class index. + + Returns: + dict: mapping from class name to class index. + """ + + return {cat: i for i, cat in enumerate(self.CLASSES)} + + def get_gt_labels(self): + """Get all ground-truth labels (categories). + + Returns: + np.ndarray: categories for all images. + """ + + gt_labels = np.array( + [self.get_data_info(i)['gt_label'] for i in range(len(self))]) + return gt_labels + + def get_cat_ids(self, idx: int) -> List[int]: + """Get category id by index. + + Args: + idx (int): Index of data. + + Returns: + cat_ids (List[int]): Image category of specified index. + """ + + return [int(self.get_data_info(idx)['gt_label'])] + + def _compat_classes(self, metainfo, classes): + """Merge the old style ``classes`` arguments to ``metainfo``.""" + if isinstance(classes, str): + # take it as a file path + class_names = mmengine.list_from_file(expanduser(classes)) + elif isinstance(classes, (tuple, list)): + class_names = classes + elif classes is not None: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + if metainfo is None: + metainfo = {} + + if classes is not None: + metainfo = {'classes': tuple(class_names), **metainfo} + + return metainfo + + def full_init(self): + """Load annotation file and set ``BaseDataset._fully_initialized`` to + True.""" + super().full_init() + + # To support the standard OpenMMLab 2.0 annotation format. Generate + # metainfo in internal format from standard metainfo format. + if 'categories' in self._metainfo and 'classes' not in self._metainfo: + categories = sorted( + self._metainfo['categories'], key=lambda x: x['id']) + self._metainfo['classes'] = tuple( + [cat['category_name'] for cat in categories]) + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. + """ + head = 'Dataset ' + self.__class__.__name__ + body = [] + if self._fully_initialized: + body.append(f'Number of samples: \t{self.__len__()}') + else: + body.append("Haven't been initialized") + + if self.CLASSES is not None: + body.append(f'Number of categories: \t{len(self.CLASSES)}') + + body.extend(self.extra_repr()) + + if len(self.pipeline.transforms) > 0: + body.append('With transforms:') + for t in self.pipeline.transforms: + body.append(f' {t}') + + lines = [head] + [' ' * 4 + line for line in body] + return '\n'.join(lines) + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [] + body.append(f'Annotation file: \t{self.ann_file}') + body.append(f'Prefix of images: \t{self.img_prefix}') + return body diff --git a/mmpretrain/datasets/builder.py b/mmpretrain/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa3872fe9931a4946368f07dfc5f5913a3e1f9f --- /dev/null +++ b/mmpretrain/datasets/builder.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpretrain.registry import DATASETS + + +def build_dataset(cfg): + """Build dataset. + + Examples: + >>> from mmpretrain.datasets import build_dataset + >>> mnist_train = build_dataset( + ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=False)) + >>> print(mnist_train) + Dataset MNIST + Number of samples: 60000 + Number of categories: 10 + Prefix of data: data/mnist/ + >>> mnist_test = build_dataset( + ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=True)) + >>> print(mnist_test) + Dataset MNIST + Number of samples: 10000 + Number of categories: 10 + Prefix of data: data/mnist/ + """ + return DATASETS.build(cfg) diff --git a/mmpretrain/datasets/caltech101.py b/mmpretrain/datasets/caltech101.py new file mode 100644 index 0000000000000000000000000000000000000000..71e5de85ff3bbf73c387a071f47113b46be36e2a --- /dev/null +++ b/mmpretrain/datasets/caltech101.py @@ -0,0 +1,113 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import CALTECH101_CATEGORIES + + +@DATASETS.register_module() +class Caltech101(BaseDataset): + """The Caltech101 Dataset. + + Support the `Caltech101 `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Caltech101 dataset directory: :: + + caltech-101 + ├── 101_ObjectCategories + │ ├── class_x + │ │ ├── xx1.jpg + │ │ ├── xx2.jpg + │ │ └── ... + │ ├── class_y + │ │ ├── yy1.jpg + │ │ ├── yy2.jpg + │ │ └── ... + │ └── ... + ├── Annotations + │ ├── class_x + │ │ ├── xx1.mat + │ │ └── ... + │ └── ... + ├── meta + │ ├── train.txt + │ └── test.txt + └── .... + + Please note that since there is no official splitting for training and + test set, you can use the train.txt and text.txt provided by us or + create your own annotation files. Here is the download + `link `_ + for the annotations. + + Args: + data_root (str): The root directory for the Caltech101 dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + + Examples: + >>> from mmpretrain.datasets import Caltech101 + >>> train_dataset = Caltech101(data_root='data/caltech-101', split='train') + >>> train_dataset + Dataset Caltech101 + Number of samples: 3060 + Number of categories: 102 + Root of dataset: data/caltech-101 + >>> test_dataset = Caltech101(data_root='data/caltech-101', split='test') + >>> test_dataset + Dataset Caltech101 + Number of samples: 6728 + Number of categories: 102 + Root of dataset: data/caltech-101 + """ # noqa: E501 + + METAINFO = {'classes': CALTECH101_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'train', **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + + if split == 'train': + ann_file = self.backend.join_path('meta', 'train.txt') + else: + ann_file = self.backend.join_path('meta', 'test.txt') + + data_prefix = '101_ObjectCategories' + test_mode = split == 'test' + + super(Caltech101, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + + for pair in pairs: + path, gt_label = pair.split() + img_path = self.backend.join_path(self.img_prefix, path) + info = dict(img_path=img_path, gt_label=int(gt_label)) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/categories.py b/mmpretrain/datasets/categories.py new file mode 100644 index 0000000000000000000000000000000000000000..011ee5c1609ee01614c485abfa69cf0d4fc35417 --- /dev/null +++ b/mmpretrain/datasets/categories.py @@ -0,0 +1,1440 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Pre-defined categories names of various datasets. + +VOC2007_CATEGORIES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', + 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', + 'sofa', 'train', 'tvmonitor') + +CUB_CATEGORIES = ( + 'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross', + 'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet', 'Parakeet_Auklet', + 'Rhinoceros_Auklet', 'Brewer_Blackbird', 'Red_winged_Blackbird', + 'Rusty_Blackbird', 'Yellow_headed_Blackbird', 'Bobolink', 'Indigo_Bunting', + 'Lazuli_Bunting', 'Painted_Bunting', 'Cardinal', 'Spotted_Catbird', + 'Gray_Catbird', 'Yellow_breasted_Chat', 'Eastern_Towhee', + 'Chuck_will_Widow', 'Brandt_Cormorant', 'Red_faced_Cormorant', + 'Pelagic_Cormorant', 'Bronzed_Cowbird', 'Shiny_Cowbird', 'Brown_Creeper', + 'American_Crow', 'Fish_Crow', 'Black_billed_Cuckoo', 'Mangrove_Cuckoo', + 'Yellow_billed_Cuckoo', 'Gray_crowned_Rosy_Finch', 'Purple_Finch', + 'Northern_Flicker', 'Acadian_Flycatcher', 'Great_Crested_Flycatcher', + 'Least_Flycatcher', 'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher', + 'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird', + 'Northern_Fulmar', 'Gadwall', 'American_Goldfinch', 'European_Goldfinch', + 'Boat_tailed_Grackle', 'Eared_Grebe', 'Horned_Grebe', 'Pied_billed_Grebe', + 'Western_Grebe', 'Blue_Grosbeak', 'Evening_Grosbeak', 'Pine_Grosbeak', + 'Rose_breasted_Grosbeak', 'Pigeon_Guillemot', 'California_Gull', + 'Glaucous_winged_Gull', 'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull', + 'Ring_billed_Gull', 'Slaty_backed_Gull', 'Western_Gull', + 'Anna_Hummingbird', 'Ruby_throated_Hummingbird', 'Rufous_Hummingbird', + 'Green_Violetear', 'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay', + 'Florida_Jay', 'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird', + 'Gray_Kingbird', 'Belted_Kingfisher', 'Green_Kingfisher', + 'Pied_Kingfisher', 'Ringed_Kingfisher', 'White_breasted_Kingfisher', + 'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard', + 'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser', + 'Mockingbird', 'Nighthawk', 'Clark_Nutcracker', 'White_breasted_Nuthatch', + 'Baltimore_Oriole', 'Hooded_Oriole', 'Orchard_Oriole', 'Scott_Oriole', + 'Ovenbird', 'Brown_Pelican', 'White_Pelican', 'Western_Wood_Pewee', + 'Sayornis', 'American_Pipit', 'Whip_poor_Will', 'Horned_Puffin', + 'Common_Raven', 'White_necked_Raven', 'American_Redstart', 'Geococcyx', + 'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow', + 'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow', + 'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow', 'Fox_Sparrow', + 'Grasshopper_Sparrow', 'Harris_Sparrow', 'Henslow_Sparrow', + 'Le_Conte_Sparrow', 'Lincoln_Sparrow', 'Nelson_Sharp_tailed_Sparrow', + 'Savannah_Sparrow', 'Seaside_Sparrow', 'Song_Sparrow', 'Tree_Sparrow', + 'Vesper_Sparrow', 'White_crowned_Sparrow', 'White_throated_Sparrow', + 'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow', 'Cliff_Swallow', + 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager', 'Artic_Tern', + 'Black_Tern', 'Caspian_Tern', 'Common_Tern', 'Elegant_Tern', + 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee', 'Brown_Thrasher', + 'Sage_Thrasher', 'Black_capped_Vireo', 'Blue_headed_Vireo', + 'Philadelphia_Vireo', 'Red_eyed_Vireo', 'Warbling_Vireo', + 'White_eyed_Vireo', 'Yellow_throated_Vireo', 'Bay_breasted_Warbler', + 'Black_and_white_Warbler', 'Black_throated_Blue_Warbler', + 'Blue_winged_Warbler', 'Canada_Warbler', 'Cape_May_Warbler', + 'Cerulean_Warbler', 'Chestnut_sided_Warbler', 'Golden_winged_Warbler', + 'Hooded_Warbler', 'Kentucky_Warbler', 'Magnolia_Warbler', + 'Mourning_Warbler', 'Myrtle_Warbler', 'Nashville_Warbler', + 'Orange_crowned_Warbler', 'Palm_Warbler', 'Pine_Warbler', + 'Prairie_Warbler', 'Prothonotary_Warbler', 'Swainson_Warbler', + 'Tennessee_Warbler', 'Wilson_Warbler', 'Worm_eating_Warbler', + 'Yellow_Warbler', 'Northern_Waterthrush', 'Louisiana_Waterthrush', + 'Bohemian_Waxwing', 'Cedar_Waxwing', 'American_Three_toed_Woodpecker', + 'Pileated_Woodpecker', 'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker', + 'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren', 'Cactus_Wren', + 'Carolina_Wren', 'House_Wren', 'Marsh_Wren', 'Rock_Wren', 'Winter_Wren', + 'Common_Yellowthroat') + +IMAGENET_CATEGORIES = ( + 'tench, Tinca tinca', + 'goldfish, Carassius auratus', + 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', # noqa: E501 + 'tiger shark, Galeocerdo cuvieri', + 'hammerhead, hammerhead shark', + 'electric ray, crampfish, numbfish, torpedo', + 'stingray', + 'cock', + 'hen', + 'ostrich, Struthio camelus', + 'brambling, Fringilla montifringilla', + 'goldfinch, Carduelis carduelis', + 'house finch, linnet, Carpodacus mexicanus', + 'junco, snowbird', + 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', + 'robin, American robin, Turdus migratorius', + 'bulbul', + 'jay', + 'magpie', + 'chickadee', + 'water ouzel, dipper', + 'kite', + 'bald eagle, American eagle, Haliaeetus leucocephalus', + 'vulture', + 'great grey owl, great gray owl, Strix nebulosa', + 'European fire salamander, Salamandra salamandra', + 'common newt, Triturus vulgaris', + 'eft', + 'spotted salamander, Ambystoma maculatum', + 'axolotl, mud puppy, Ambystoma mexicanum', + 'bullfrog, Rana catesbeiana', + 'tree frog, tree-frog', + 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', + 'loggerhead, loggerhead turtle, Caretta caretta', + 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', # noqa: E501 + 'mud turtle', + 'terrapin', + 'box turtle, box tortoise', + 'banded gecko', + 'common iguana, iguana, Iguana iguana', + 'American chameleon, anole, Anolis carolinensis', + 'whiptail, whiptail lizard', + 'agama', + 'frilled lizard, Chlamydosaurus kingi', + 'alligator lizard', + 'Gila monster, Heloderma suspectum', + 'green lizard, Lacerta viridis', + 'African chameleon, Chamaeleo chamaeleon', + 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', # noqa: E501 + 'African crocodile, Nile crocodile, Crocodylus niloticus', + 'American alligator, Alligator mississipiensis', + 'triceratops', + 'thunder snake, worm snake, Carphophis amoenus', + 'ringneck snake, ring-necked snake, ring snake', + 'hognose snake, puff adder, sand viper', + 'green snake, grass snake', + 'king snake, kingsnake', + 'garter snake, grass snake', + 'water snake', + 'vine snake', + 'night snake, Hypsiglena torquata', + 'boa constrictor, Constrictor constrictor', + 'rock python, rock snake, Python sebae', + 'Indian cobra, Naja naja', + 'green mamba', + 'sea snake', + 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', + 'diamondback, diamondback rattlesnake, Crotalus adamanteus', + 'sidewinder, horned rattlesnake, Crotalus cerastes', + 'trilobite', + 'harvestman, daddy longlegs, Phalangium opilio', + 'scorpion', + 'black and gold garden spider, Argiope aurantia', + 'barn spider, Araneus cavaticus', + 'garden spider, Aranea diademata', + 'black widow, Latrodectus mactans', + 'tarantula', + 'wolf spider, hunting spider', + 'tick', + 'centipede', + 'black grouse', + 'ptarmigan', + 'ruffed grouse, partridge, Bonasa umbellus', + 'prairie chicken, prairie grouse, prairie fowl', + 'peacock', + 'quail', + 'partridge', + 'African grey, African gray, Psittacus erithacus', + 'macaw', + 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', + 'lorikeet', + 'coucal', + 'bee eater', + 'hornbill', + 'hummingbird', + 'jacamar', + 'toucan', + 'drake', + 'red-breasted merganser, Mergus serrator', + 'goose', + 'black swan, Cygnus atratus', + 'tusker', + 'echidna, spiny anteater, anteater', + 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', # noqa: E501 + 'wallaby, brush kangaroo', + 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', # noqa: E501 + 'wombat', + 'jellyfish', + 'sea anemone, anemone', + 'brain coral', + 'flatworm, platyhelminth', + 'nematode, nematode worm, roundworm', + 'conch', + 'snail', + 'slug', + 'sea slug, nudibranch', + 'chiton, coat-of-mail shell, sea cradle, polyplacophore', + 'chambered nautilus, pearly nautilus, nautilus', + 'Dungeness crab, Cancer magister', + 'rock crab, Cancer irroratus', + 'fiddler crab', + 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', # noqa: E501 + 'American lobster, Northern lobster, Maine lobster, Homarus americanus', # noqa: E501 + 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', # noqa: E501 + 'crayfish, crawfish, crawdad, crawdaddy', + 'hermit crab', + 'isopod', + 'white stork, Ciconia ciconia', + 'black stork, Ciconia nigra', + 'spoonbill', + 'flamingo', + 'little blue heron, Egretta caerulea', + 'American egret, great white heron, Egretta albus', + 'bittern', + 'crane', + 'limpkin, Aramus pictus', + 'European gallinule, Porphyrio porphyrio', + 'American coot, marsh hen, mud hen, water hen, Fulica americana', + 'bustard', + 'ruddy turnstone, Arenaria interpres', + 'red-backed sandpiper, dunlin, Erolia alpina', + 'redshank, Tringa totanus', + 'dowitcher', + 'oystercatcher, oyster catcher', + 'pelican', + 'king penguin, Aptenodytes patagonica', + 'albatross, mollymawk', + 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', # noqa: E501 + 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', + 'dugong, Dugong dugon', + 'sea lion', + 'Chihuahua', + 'Japanese spaniel', + 'Maltese dog, Maltese terrier, Maltese', + 'Pekinese, Pekingese, Peke', + 'Shih-Tzu', + 'Blenheim spaniel', + 'papillon', + 'toy terrier', + 'Rhodesian ridgeback', + 'Afghan hound, Afghan', + 'basset, basset hound', + 'beagle', + 'bloodhound, sleuthhound', + 'bluetick', + 'black-and-tan coonhound', + 'Walker hound, Walker foxhound', + 'English foxhound', + 'redbone', + 'borzoi, Russian wolfhound', + 'Irish wolfhound', + 'Italian greyhound', + 'whippet', + 'Ibizan hound, Ibizan Podenco', + 'Norwegian elkhound, elkhound', + 'otterhound, otter hound', + 'Saluki, gazelle hound', + 'Scottish deerhound, deerhound', + 'Weimaraner', + 'Staffordshire bullterrier, Staffordshire bull terrier', + 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', # noqa: E501 + 'Bedlington terrier', + 'Border terrier', + 'Kerry blue terrier', + 'Irish terrier', + 'Norfolk terrier', + 'Norwich terrier', + 'Yorkshire terrier', + 'wire-haired fox terrier', + 'Lakeland terrier', + 'Sealyham terrier, Sealyham', + 'Airedale, Airedale terrier', + 'cairn, cairn terrier', + 'Australian terrier', + 'Dandie Dinmont, Dandie Dinmont terrier', + 'Boston bull, Boston terrier', + 'miniature schnauzer', + 'giant schnauzer', + 'standard schnauzer', + 'Scotch terrier, Scottish terrier, Scottie', + 'Tibetan terrier, chrysanthemum dog', + 'silky terrier, Sydney silky', + 'soft-coated wheaten terrier', + 'West Highland white terrier', + 'Lhasa, Lhasa apso', + 'flat-coated retriever', + 'curly-coated retriever', + 'golden retriever', + 'Labrador retriever', + 'Chesapeake Bay retriever', + 'German short-haired pointer', + 'vizsla, Hungarian pointer', + 'English setter', + 'Irish setter, red setter', + 'Gordon setter', + 'Brittany spaniel', + 'clumber, clumber spaniel', + 'English springer, English springer spaniel', + 'Welsh springer spaniel', + 'cocker spaniel, English cocker spaniel, cocker', + 'Sussex spaniel', + 'Irish water spaniel', + 'kuvasz', + 'schipperke', + 'groenendael', + 'malinois', + 'briard', + 'kelpie', + 'komondor', + 'Old English sheepdog, bobtail', + 'Shetland sheepdog, Shetland sheep dog, Shetland', + 'collie', + 'Border collie', + 'Bouvier des Flandres, Bouviers des Flandres', + 'Rottweiler', + 'German shepherd, German shepherd dog, German police dog, alsatian', + 'Doberman, Doberman pinscher', + 'miniature pinscher', + 'Greater Swiss Mountain dog', + 'Bernese mountain dog', + 'Appenzeller', + 'EntleBucher', + 'boxer', + 'bull mastiff', + 'Tibetan mastiff', + 'French bulldog', + 'Great Dane', + 'Saint Bernard, St Bernard', + 'Eskimo dog, husky', + 'malamute, malemute, Alaskan malamute', + 'Siberian husky', + 'dalmatian, coach dog, carriage dog', + 'affenpinscher, monkey pinscher, monkey dog', + 'basenji', + 'pug, pug-dog', + 'Leonberg', + 'Newfoundland, Newfoundland dog', + 'Great Pyrenees', + 'Samoyed, Samoyede', + 'Pomeranian', + 'chow, chow chow', + 'keeshond', + 'Brabancon griffon', + 'Pembroke, Pembroke Welsh corgi', + 'Cardigan, Cardigan Welsh corgi', + 'toy poodle', + 'miniature poodle', + 'standard poodle', + 'Mexican hairless', + 'timber wolf, grey wolf, gray wolf, Canis lupus', + 'white wolf, Arctic wolf, Canis lupus tundrarum', + 'red wolf, maned wolf, Canis rufus, Canis niger', + 'coyote, prairie wolf, brush wolf, Canis latrans', + 'dingo, warrigal, warragal, Canis dingo', + 'dhole, Cuon alpinus', + 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', + 'hyena, hyaena', + 'red fox, Vulpes vulpes', + 'kit fox, Vulpes macrotis', + 'Arctic fox, white fox, Alopex lagopus', + 'grey fox, gray fox, Urocyon cinereoargenteus', + 'tabby, tabby cat', + 'tiger cat', + 'Persian cat', + 'Siamese cat, Siamese', + 'Egyptian cat', + 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', # noqa: E501 + 'lynx, catamount', + 'leopard, Panthera pardus', + 'snow leopard, ounce, Panthera uncia', + 'jaguar, panther, Panthera onca, Felis onca', + 'lion, king of beasts, Panthera leo', + 'tiger, Panthera tigris', + 'cheetah, chetah, Acinonyx jubatus', + 'brown bear, bruin, Ursus arctos', + 'American black bear, black bear, Ursus americanus, Euarctos americanus', # noqa: E501 + 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', + 'sloth bear, Melursus ursinus, Ursus ursinus', + 'mongoose', + 'meerkat, mierkat', + 'tiger beetle', + 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', + 'ground beetle, carabid beetle', + 'long-horned beetle, longicorn, longicorn beetle', + 'leaf beetle, chrysomelid', + 'dung beetle', + 'rhinoceros beetle', + 'weevil', + 'fly', + 'bee', + 'ant, emmet, pismire', + 'grasshopper, hopper', + 'cricket', + 'walking stick, walkingstick, stick insect', + 'cockroach, roach', + 'mantis, mantid', + 'cicada, cicala', + 'leafhopper', + 'lacewing, lacewing fly', + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", # noqa: E501 + 'damselfly', + 'admiral', + 'ringlet, ringlet butterfly', + 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', + 'cabbage butterfly', + 'sulphur butterfly, sulfur butterfly', + 'lycaenid, lycaenid butterfly', + 'starfish, sea star', + 'sea urchin', + 'sea cucumber, holothurian', + 'wood rabbit, cottontail, cottontail rabbit', + 'hare', + 'Angora, Angora rabbit', + 'hamster', + 'porcupine, hedgehog', + 'fox squirrel, eastern fox squirrel, Sciurus niger', + 'marmot', + 'beaver', + 'guinea pig, Cavia cobaya', + 'sorrel', + 'zebra', + 'hog, pig, grunter, squealer, Sus scrofa', + 'wild boar, boar, Sus scrofa', + 'warthog', + 'hippopotamus, hippo, river horse, Hippopotamus amphibius', + 'ox', + 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', + 'bison', + 'ram, tup', + 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', # noqa: E501 + 'ibex, Capra ibex', + 'hartebeest', + 'impala, Aepyceros melampus', + 'gazelle', + 'Arabian camel, dromedary, Camelus dromedarius', + 'llama', + 'weasel', + 'mink', + 'polecat, fitch, foulmart, foumart, Mustela putorius', + 'black-footed ferret, ferret, Mustela nigripes', + 'otter', + 'skunk, polecat, wood pussy', + 'badger', + 'armadillo', + 'three-toed sloth, ai, Bradypus tridactylus', + 'orangutan, orang, orangutang, Pongo pygmaeus', + 'gorilla, Gorilla gorilla', + 'chimpanzee, chimp, Pan troglodytes', + 'gibbon, Hylobates lar', + 'siamang, Hylobates syndactylus, Symphalangus syndactylus', + 'guenon, guenon monkey', + 'patas, hussar monkey, Erythrocebus patas', + 'baboon', + 'macaque', + 'langur', + 'colobus, colobus monkey', + 'proboscis monkey, Nasalis larvatus', + 'marmoset', + 'capuchin, ringtail, Cebus capucinus', + 'howler monkey, howler', + 'titi, titi monkey', + 'spider monkey, Ateles geoffroyi', + 'squirrel monkey, Saimiri sciureus', + 'Madagascar cat, ring-tailed lemur, Lemur catta', + 'indri, indris, Indri indri, Indri brevicaudatus', + 'Indian elephant, Elephas maximus', + 'African elephant, Loxodonta africana', + 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', + 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', + 'barracouta, snoek', + 'eel', + 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', # noqa: E501 + 'rock beauty, Holocanthus tricolor', + 'anemone fish', + 'sturgeon', + 'gar, garfish, garpike, billfish, Lepisosteus osseus', + 'lionfish', + 'puffer, pufferfish, blowfish, globefish', + 'abacus', + 'abaya', + "academic gown, academic robe, judge's robe", + 'accordion, piano accordion, squeeze box', + 'acoustic guitar', + 'aircraft carrier, carrier, flattop, attack aircraft carrier', + 'airliner', + 'airship, dirigible', + 'altar', + 'ambulance', + 'amphibian, amphibious vehicle', + 'analog clock', + 'apiary, bee house', + 'apron', + 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', # noqa: E501 + 'assault rifle, assault gun', + 'backpack, back pack, knapsack, packsack, rucksack, haversack', + 'bakery, bakeshop, bakehouse', + 'balance beam, beam', + 'balloon', + 'ballpoint, ballpoint pen, ballpen, Biro', + 'Band Aid', + 'banjo', + 'bannister, banister, balustrade, balusters, handrail', + 'barbell', + 'barber chair', + 'barbershop', + 'barn', + 'barometer', + 'barrel, cask', + 'barrow, garden cart, lawn cart, wheelbarrow', + 'baseball', + 'basketball', + 'bassinet', + 'bassoon', + 'bathing cap, swimming cap', + 'bath towel', + 'bathtub, bathing tub, bath, tub', + 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', # noqa: E501 + 'beacon, lighthouse, beacon light, pharos', + 'beaker', + 'bearskin, busby, shako', + 'beer bottle', + 'beer glass', + 'bell cote, bell cot', + 'bib', + 'bicycle-built-for-two, tandem bicycle, tandem', + 'bikini, two-piece', + 'binder, ring-binder', + 'binoculars, field glasses, opera glasses', + 'birdhouse', + 'boathouse', + 'bobsled, bobsleigh, bob', + 'bolo tie, bolo, bola tie, bola', + 'bonnet, poke bonnet', + 'bookcase', + 'bookshop, bookstore, bookstall', + 'bottlecap', + 'bow', + 'bow tie, bow-tie, bowtie', + 'brass, memorial tablet, plaque', + 'brassiere, bra, bandeau', + 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', + 'breastplate, aegis, egis', + 'broom', + 'bucket, pail', + 'buckle', + 'bulletproof vest', + 'bullet train, bullet', + 'butcher shop, meat market', + 'cab, hack, taxi, taxicab', + 'caldron, cauldron', + 'candle, taper, wax light', + 'cannon', + 'canoe', + 'can opener, tin opener', + 'cardigan', + 'car mirror', + 'carousel, carrousel, merry-go-round, roundabout, whirligig', + "carpenter's kit, tool kit", + 'carton', + 'car wheel', + 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', # noqa: E501 + 'cassette', + 'cassette player', + 'castle', + 'catamaran', + 'CD player', + 'cello, violoncello', + 'cellular telephone, cellular phone, cellphone, cell, mobile phone', + 'chain', + 'chainlink fence', + 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', # noqa: E501 + 'chain saw, chainsaw', + 'chest', + 'chiffonier, commode', + 'chime, bell, gong', + 'china cabinet, china closet', + 'Christmas stocking', + 'church, church building', + 'cinema, movie theater, movie theatre, movie house, picture palace', + 'cleaver, meat cleaver, chopper', + 'cliff dwelling', + 'cloak', + 'clog, geta, patten, sabot', + 'cocktail shaker', + 'coffee mug', + 'coffeepot', + 'coil, spiral, volute, whorl, helix', + 'combination lock', + 'computer keyboard, keypad', + 'confectionery, confectionary, candy store', + 'container ship, containership, container vessel', + 'convertible', + 'corkscrew, bottle screw', + 'cornet, horn, trumpet, trump', + 'cowboy boot', + 'cowboy hat, ten-gallon hat', + 'cradle', + 'crane', + 'crash helmet', + 'crate', + 'crib, cot', + 'Crock Pot', + 'croquet ball', + 'crutch', + 'cuirass', + 'dam, dike, dyke', + 'desk', + 'desktop computer', + 'dial telephone, dial phone', + 'diaper, nappy, napkin', + 'digital clock', + 'digital watch', + 'dining table, board', + 'dishrag, dishcloth', + 'dishwasher, dish washer, dishwashing machine', + 'disk brake, disc brake', + 'dock, dockage, docking facility', + 'dogsled, dog sled, dog sleigh', + 'dome', + 'doormat, welcome mat', + 'drilling platform, offshore rig', + 'drum, membranophone, tympan', + 'drumstick', + 'dumbbell', + 'Dutch oven', + 'electric fan, blower', + 'electric guitar', + 'electric locomotive', + 'entertainment center', + 'envelope', + 'espresso maker', + 'face powder', + 'feather boa, boa', + 'file, file cabinet, filing cabinet', + 'fireboat', + 'fire engine, fire truck', + 'fire screen, fireguard', + 'flagpole, flagstaff', + 'flute, transverse flute', + 'folding chair', + 'football helmet', + 'forklift', + 'fountain', + 'fountain pen', + 'four-poster', + 'freight car', + 'French horn, horn', + 'frying pan, frypan, skillet', + 'fur coat', + 'garbage truck, dustcart', + 'gasmask, respirator, gas helmet', + 'gas pump, gasoline pump, petrol pump, island dispenser', + 'goblet', + 'go-kart', + 'golf ball', + 'golfcart, golf cart', + 'gondola', + 'gong, tam-tam', + 'gown', + 'grand piano, grand', + 'greenhouse, nursery, glasshouse', + 'grille, radiator grille', + 'grocery store, grocery, food market, market', + 'guillotine', + 'hair slide', + 'hair spray', + 'half track', + 'hammer', + 'hamper', + 'hand blower, blow dryer, blow drier, hair dryer, hair drier', + 'hand-held computer, hand-held microcomputer', + 'handkerchief, hankie, hanky, hankey', + 'hard disc, hard disk, fixed disk', + 'harmonica, mouth organ, harp, mouth harp', + 'harp', + 'harvester, reaper', + 'hatchet', + 'holster', + 'home theater, home theatre', + 'honeycomb', + 'hook, claw', + 'hoopskirt, crinoline', + 'horizontal bar, high bar', + 'horse cart, horse-cart', + 'hourglass', + 'iPod', + 'iron, smoothing iron', + "jack-o'-lantern", + 'jean, blue jean, denim', + 'jeep, landrover', + 'jersey, T-shirt, tee shirt', + 'jigsaw puzzle', + 'jinrikisha, ricksha, rickshaw', + 'joystick', + 'kimono', + 'knee pad', + 'knot', + 'lab coat, laboratory coat', + 'ladle', + 'lampshade, lamp shade', + 'laptop, laptop computer', + 'lawn mower, mower', + 'lens cap, lens cover', + 'letter opener, paper knife, paperknife', + 'library', + 'lifeboat', + 'lighter, light, igniter, ignitor', + 'limousine, limo', + 'liner, ocean liner', + 'lipstick, lip rouge', + 'Loafer', + 'lotion', + 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', # noqa: E501 + "loupe, jeweler's loupe", + 'lumbermill, sawmill', + 'magnetic compass', + 'mailbag, postbag', + 'mailbox, letter box', + 'maillot', + 'maillot, tank suit', + 'manhole cover', + 'maraca', + 'marimba, xylophone', + 'mask', + 'matchstick', + 'maypole', + 'maze, labyrinth', + 'measuring cup', + 'medicine chest, medicine cabinet', + 'megalith, megalithic structure', + 'microphone, mike', + 'microwave, microwave oven', + 'military uniform', + 'milk can', + 'minibus', + 'miniskirt, mini', + 'minivan', + 'missile', + 'mitten', + 'mixing bowl', + 'mobile home, manufactured home', + 'Model T', + 'modem', + 'monastery', + 'monitor', + 'moped', + 'mortar', + 'mortarboard', + 'mosque', + 'mosquito net', + 'motor scooter, scooter', + 'mountain bike, all-terrain bike, off-roader', + 'mountain tent', + 'mouse, computer mouse', + 'mousetrap', + 'moving van', + 'muzzle', + 'nail', + 'neck brace', + 'necklace', + 'nipple', + 'notebook, notebook computer', + 'obelisk', + 'oboe, hautboy, hautbois', + 'ocarina, sweet potato', + 'odometer, hodometer, mileometer, milometer', + 'oil filter', + 'organ, pipe organ', + 'oscilloscope, scope, cathode-ray oscilloscope, CRO', + 'overskirt', + 'oxcart', + 'oxygen mask', + 'packet', + 'paddle, boat paddle', + 'paddlewheel, paddle wheel', + 'padlock', + 'paintbrush', + "pajama, pyjama, pj's, jammies", + 'palace', + 'panpipe, pandean pipe, syrinx', + 'paper towel', + 'parachute, chute', + 'parallel bars, bars', + 'park bench', + 'parking meter', + 'passenger car, coach, carriage', + 'patio, terrace', + 'pay-phone, pay-station', + 'pedestal, plinth, footstall', + 'pencil box, pencil case', + 'pencil sharpener', + 'perfume, essence', + 'Petri dish', + 'photocopier', + 'pick, plectrum, plectron', + 'pickelhaube', + 'picket fence, paling', + 'pickup, pickup truck', + 'pier', + 'piggy bank, penny bank', + 'pill bottle', + 'pillow', + 'ping-pong ball', + 'pinwheel', + 'pirate, pirate ship', + 'pitcher, ewer', + "plane, carpenter's plane, woodworking plane", + 'planetarium', + 'plastic bag', + 'plate rack', + 'plow, plough', + "plunger, plumber's helper", + 'Polaroid camera, Polaroid Land camera', + 'pole', + 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', # noqa: E501 + 'poncho', + 'pool table, billiard table, snooker table', + 'pop bottle, soda bottle', + 'pot, flowerpot', + "potter's wheel", + 'power drill', + 'prayer rug, prayer mat', + 'printer', + 'prison, prison house', + 'projectile, missile', + 'projector', + 'puck, hockey puck', + 'punching bag, punch bag, punching ball, punchball', + 'purse', + 'quill, quill pen', + 'quilt, comforter, comfort, puff', + 'racer, race car, racing car', + 'racket, racquet', + 'radiator', + 'radio, wireless', + 'radio telescope, radio reflector', + 'rain barrel', + 'recreational vehicle, RV, R.V.', + 'reel', + 'reflex camera', + 'refrigerator, icebox', + 'remote control, remote', + 'restaurant, eating house, eating place, eatery', + 'revolver, six-gun, six-shooter', + 'rifle', + 'rocking chair, rocker', + 'rotisserie', + 'rubber eraser, rubber, pencil eraser', + 'rugby ball', + 'rule, ruler', + 'running shoe', + 'safe', + 'safety pin', + 'saltshaker, salt shaker', + 'sandal', + 'sarong', + 'sax, saxophone', + 'scabbard', + 'scale, weighing machine', + 'school bus', + 'schooner', + 'scoreboard', + 'screen, CRT screen', + 'screw', + 'screwdriver', + 'seat belt, seatbelt', + 'sewing machine', + 'shield, buckler', + 'shoe shop, shoe-shop, shoe store', + 'shoji', + 'shopping basket', + 'shopping cart', + 'shovel', + 'shower cap', + 'shower curtain', + 'ski', + 'ski mask', + 'sleeping bag', + 'slide rule, slipstick', + 'sliding door', + 'slot, one-armed bandit', + 'snorkel', + 'snowmobile', + 'snowplow, snowplough', + 'soap dispenser', + 'soccer ball', + 'sock', + 'solar dish, solar collector, solar furnace', + 'sombrero', + 'soup bowl', + 'space bar', + 'space heater', + 'space shuttle', + 'spatula', + 'speedboat', + "spider web, spider's web", + 'spindle', + 'sports car, sport car', + 'spotlight, spot', + 'stage', + 'steam locomotive', + 'steel arch bridge', + 'steel drum', + 'stethoscope', + 'stole', + 'stone wall', + 'stopwatch, stop watch', + 'stove', + 'strainer', + 'streetcar, tram, tramcar, trolley, trolley car', + 'stretcher', + 'studio couch, day bed', + 'stupa, tope', + 'submarine, pigboat, sub, U-boat', + 'suit, suit of clothes', + 'sundial', + 'sunglass', + 'sunglasses, dark glasses, shades', + 'sunscreen, sunblock, sun blocker', + 'suspension bridge', + 'swab, swob, mop', + 'sweatshirt', + 'swimming trunks, bathing trunks', + 'swing', + 'switch, electric switch, electrical switch', + 'syringe', + 'table lamp', + 'tank, army tank, armored combat vehicle, armoured combat vehicle', + 'tape player', + 'teapot', + 'teddy, teddy bear', + 'television, television system', + 'tennis ball', + 'thatch, thatched roof', + 'theater curtain, theatre curtain', + 'thimble', + 'thresher, thrasher, threshing machine', + 'throne', + 'tile roof', + 'toaster', + 'tobacco shop, tobacconist shop, tobacconist', + 'toilet seat', + 'torch', + 'totem pole', + 'tow truck, tow car, wrecker', + 'toyshop', + 'tractor', + 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', # noqa: E501 + 'tray', + 'trench coat', + 'tricycle, trike, velocipede', + 'trimaran', + 'tripod', + 'triumphal arch', + 'trolleybus, trolley coach, trackless trolley', + 'trombone', + 'tub, vat', + 'turnstile', + 'typewriter keyboard', + 'umbrella', + 'unicycle, monocycle', + 'upright, upright piano', + 'vacuum, vacuum cleaner', + 'vase', + 'vault', + 'velvet', + 'vending machine', + 'vestment', + 'viaduct', + 'violin, fiddle', + 'volleyball', + 'waffle iron', + 'wall clock', + 'wallet, billfold, notecase, pocketbook', + 'wardrobe, closet, press', + 'warplane, military plane', + 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', + 'washer, automatic washer, washing machine', + 'water bottle', + 'water jug', + 'water tower', + 'whiskey jug', + 'whistle', + 'wig', + 'window screen', + 'window shade', + 'Windsor tie', + 'wine bottle', + 'wing', + 'wok', + 'wooden spoon', + 'wool, woolen, woollen', + 'worm fence, snake fence, snake-rail fence, Virginia fence', + 'wreck', + 'yawl', + 'yurt', + 'web site, website, internet site, site', + 'comic book', + 'crossword puzzle, crossword', + 'street sign', + 'traffic light, traffic signal, stoplight', + 'book jacket, dust cover, dust jacket, dust wrapper', + 'menu', + 'plate', + 'guacamole', + 'consomme', + 'hot pot, hotpot', + 'trifle', + 'ice cream, icecream', + 'ice lolly, lolly, lollipop, popsicle', + 'French loaf', + 'bagel, beigel', + 'pretzel', + 'cheeseburger', + 'hotdog, hot dog, red hot', + 'mashed potato', + 'head cabbage', + 'broccoli', + 'cauliflower', + 'zucchini, courgette', + 'spaghetti squash', + 'acorn squash', + 'butternut squash', + 'cucumber, cuke', + 'artichoke, globe artichoke', + 'bell pepper', + 'cardoon', + 'mushroom', + 'Granny Smith', + 'strawberry', + 'orange', + 'lemon', + 'fig', + 'pineapple, ananas', + 'banana', + 'jackfruit, jak, jack', + 'custard apple', + 'pomegranate', + 'hay', + 'carbonara', + 'chocolate sauce, chocolate syrup', + 'dough', + 'meat loaf, meatloaf', + 'pizza, pizza pie', + 'potpie', + 'burrito', + 'red wine', + 'espresso', + 'cup', + 'eggnog', + 'alp', + 'bubble', + 'cliff, drop, drop-off', + 'coral reef', + 'geyser', + 'lakeside, lakeshore', + 'promontory, headland, head, foreland', + 'sandbar, sand bar', + 'seashore, coast, seacoast, sea-coast', + 'valley, vale', + 'volcano', + 'ballplayer, baseball player', + 'groom, bridegroom', + 'scuba diver', + 'rapeseed', + 'daisy', + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", # noqa: E501 + 'corn', + 'acorn', + 'hip, rose hip, rosehip', + 'buckeye, horse chestnut, conker', + 'coral fungus', + 'agaric', + 'gyromitra', + 'stinkhorn, carrion fungus', + 'earthstar', + 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', # noqa: E501 + 'bolete', + 'ear, spike, capitulum', + 'toilet tissue, toilet paper, bathroom tissue') + +CIFAR10_CATEGORIES = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', + 'frog', 'horse', 'ship', 'truck') + +CIFAR100_CATEGORIES = ( + 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', + 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', + 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', + 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', + 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', + 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', + 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', + 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', + 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', + 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', + 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', + 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', + 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', + 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', + 'woman', 'worm') + +MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', + '5 - five', '6 - six', '7 - seven', '8 - eight', + '9 - nine') + +FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', + 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', + 'Ankle boot') + +PLACES205_CATEGORIES = ( + 'abbey', 'airport_terminal', 'alley', 'amphitheater', 'amusement_park', + 'aquarium', 'aqueduct', 'arch', 'art_gallery', 'art_studio', + 'assembly_line', 'attic', 'auditorium', 'apartment_building/outdoor', + 'badlands', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', + 'baseball_field', 'basement', 'basilica', 'bayou', 'beauty_salon', + 'bedroom', 'boardwalk', 'boat_deck', 'bookstore', 'botanical_garden', + 'bowling_alley', 'boxing_ring', 'bridge', 'building_facade', + 'bus_interior', 'butchers_shop', 'butte', 'bakery/shop', 'cafeteria', + 'campsite', 'candy_store', 'canyon', 'castle', 'cemetery', 'chalet', + 'classroom', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', + 'conference_center', 'conference_room', 'construction_site', 'corn_field', + 'corridor', 'cottage_garden', 'courthouse', 'courtyard', 'creek', + 'crevasse', 'crosswalk', 'cathedral/outdoor', 'church/outdoor', 'dam', + 'dining_room', 'dock', 'dorm_room', 'driveway', 'desert/sand', + 'desert/vegetation', 'dinette/home', 'doorway/outdoor', 'engine_room', + 'excavation', 'fairway', 'fire_escape', 'fire_station', 'food_court', + 'forest_path', 'forest_road', 'formal_garden', 'fountain', + 'field/cultivated', 'field/wild', 'galley', 'game_room', 'garbage_dump', + 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'herb_garden', + 'highway', 'home_office', 'hospital', 'hospital_room', 'hot_spring', + 'hotel_room', 'hotel/outdoor', 'ice_cream_parlor', 'iceberg', 'igloo', + 'islet', 'ice_skating_rink/outdoor', 'inn/outdoor', 'jail_cell', 'kasbah', + 'kindergarden_classroom', 'kitchen', 'kitchenette', 'laundromat', + 'lighthouse', 'living_room', 'lobby', 'locker_room', 'mansion', 'marsh', + 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain', + 'mountain_snowy', 'music_studio', 'market/outdoor', 'monastery/outdoor', + 'museum/indoor', 'nursery', 'ocean', 'office', 'office_building', + 'orchard', 'pagoda', 'palace', 'pantry', 'parking_lot', 'parlor', + 'pasture', 'patio', 'pavilion', 'phone_booth', 'picnic_area', 'playground', + 'plaza', 'pond', 'pulpit', 'racecourse', 'raft', 'railroad_track', + 'rainforest', 'reception', 'residential_neighborhood', 'restaurant', + 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'river', + 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sandbar', 'schoolhouse', + 'sea_cliff', 'shed', 'shoe_shop', 'shopfront', 'shower', 'ski_resort', + 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'staircase', + 'supermarket', 'swamp', 'stadium/baseball', 'stadium/football', + 'stage/indoor', 'subway_station/platform', 'swimming_pool/outdoor', + 'television_studio', 'topiary_garden', 'tower', 'train_railway', + 'tree_farm', 'trench', 'temple/east_asia', 'temple/south_asia', + 'track/outdoor', 'train_station/platform', 'underwater/coral_reef', + 'valley', 'vegetable_garden', 'veranda', 'viaduct', 'volcano', + 'waiting_room', 'water_tower', 'watering_hole', 'wheat_field', 'wind_farm', + 'windmill', 'yard') + +OxfordIIITPet_CATEGORIES = ( + 'Abyssinian', 'american_bulldog', 'american_pit_bull_terrier', + 'basset_hound', 'beagle', 'Bengal', 'Birman', 'Bombay', 'boxer', + 'British_Shorthair', 'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel', + 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', + 'japanese_chin', 'keeshond', 'leonberger', 'Maine_Coon', + 'miniature_pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug', + 'Ragdoll', 'Russian_Blue', 'saint_bernard', 'samoyed', 'scottish_terrier', + 'shiba_inu', 'Siamese', 'Sphynx', 'staffordshire_bull_terrier', + 'wheaten_terrier', 'yorkshire_terrier') + +DTD_CATEGORIES = ('banded', 'blotchy', 'braided', 'bubbly', 'bumpy', + 'chequered', 'cobwebbed', 'cracked', 'crosshatched', + 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', + 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', + 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', + 'matted', 'meshed', 'paisley', 'perforated', 'pitted', + 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', + 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', + 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', + 'wrinkled', 'zigzagged') + +FGVCAIRCRAFT_CATEGORIES = ( + '707-320', '727-200', '737-200', '737-300', '737-400', '737-500', + '737-600', '737-700', '737-800', '737-900', '747-100', '747-200', + '747-300', '747-400', '757-200', '757-300', '767-200', '767-300', + '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320', + 'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500', + 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200', + 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47', + 'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525', + 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30', + 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300', 'DR-400', + 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145', + 'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A/B', 'F/A-18', + 'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70', + 'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76', + 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200', + 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134', + 'Tu-154', 'Yak-42') + +STANFORDCARS_CATEGORIES = ( + 'AM General Hummer SUV 2000', 'Acura RL Sedan 2012', 'Acura TL Sedan 2012', + 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012', + 'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012', + 'Aston Martin V8 Vantage Convertible 2012', + 'Aston Martin V8 Vantage Coupe 2012', + 'Aston Martin Virage Convertible 2012', 'Aston Martin Virage Coupe 2012', + 'Audi RS 4 Convertible 2008', 'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012', + 'Audi R8 Coupe 2012', 'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994', + 'Audi 100 Wagon 1994', 'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011', + 'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012', + 'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012', + 'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012', + 'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012', + 'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007', + 'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012', + 'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012', + 'BMW Z4 Convertible 2012', + 'Bentley Continental Supersports Conv. Convertible 2012', + 'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011', + 'Bentley Continental GT Coupe 2012', 'Bentley Continental GT Coupe 2007', + 'Bentley Continental Flying Spur Sedan 2007', + 'Bugatti Veyron 16.4 Convertible 2009', 'Bugatti Veyron 16.4 Coupe 2009', + 'Buick Regal GS 2012', 'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012', + 'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012', + 'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007', + 'Chevrolet Silverado 1500 Hybrid Crew Cab 2012', + 'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012', + 'Chevrolet Corvette Ron Fellows Edition Z06 2007', + 'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012', + 'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007', + 'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012', + 'Chevrolet Express Cargo Van 2007', 'Chevrolet Avalanche Crew Cab 2012', + 'Chevrolet Cobalt SS 2010', 'Chevrolet Malibu Hybrid Sedan 2010', + 'Chevrolet TrailBlazer SS 2009', + 'Chevrolet Silverado 2500HD Regular Cab 2012', + 'Chevrolet Silverado 1500 Classic Extended Cab 2007', + 'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007', + 'Chevrolet Malibu Sedan 2007', + 'Chevrolet Silverado 1500 Extended Cab 2012', + 'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009', + 'Chrysler Sebring Convertible 2010', + 'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010', + 'Chrysler Crossfire Convertible 2008', + 'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002', + 'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007', + 'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010', + 'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009', + 'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010', + 'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008', + 'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012', + 'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012', + 'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998', + 'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012', + 'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012', + 'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012', + 'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012', + 'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007', + 'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012', + 'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006', + 'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007', + 'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012', + 'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012', 'GMC Savana Van 2012', + 'GMC Yukon Hybrid SUV 2012', 'GMC Acadia SUV 2012', + 'GMC Canyon Extended Cab 2012', 'Geo Metro Convertible 1993', + 'HUMMER H3T Crew Cab 2010', 'HUMMER H2 SUT Crew Cab 2009', + 'Honda Odyssey Minivan 2012', 'Honda Odyssey Minivan 2007', + 'Honda Accord Coupe 2012', 'Honda Accord Sedan 2012', + 'Hyundai Veloster Hatchback 2012', 'Hyundai Santa Fe SUV 2012', + 'Hyundai Tucson SUV 2012', 'Hyundai Veracruz SUV 2012', + 'Hyundai Sonata Hybrid Sedan 2012', 'Hyundai Elantra Sedan 2007', + 'Hyundai Accent Sedan 2012', 'Hyundai Genesis Sedan 2012', + 'Hyundai Sonata Sedan 2012', 'Hyundai Elantra Touring Hatchback 2012', + 'Hyundai Azera Sedan 2012', 'Infiniti G Coupe IPL 2012', + 'Infiniti QX56 SUV 2011', 'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012', + 'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012', 'Jeep Liberty SUV 2012', + 'Jeep Grand Cherokee SUV 2012', 'Jeep Compass SUV 2012', + 'Lamborghini Reventon Coupe 2008', 'Lamborghini Aventador Coupe 2012', + 'Lamborghini Gallardo LP 570-4 Superleggera 2012', + 'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012', + 'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011', + 'MINI Cooper Roadster Convertible 2012', + 'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011', + 'McLaren MP4-12C Coupe 2012', 'Mercedes-Benz 300-Class Convertible 1993', + 'Mercedes-Benz C-Class Sedan 2012', 'Mercedes-Benz SL-Class Coupe 2009', + 'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012', + 'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012', + 'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012', + 'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998', + 'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012', + 'Ram C/V Cargo Van Minivan 2012', + 'Rolls-Royce Phantom Drophead Coupe Convertible 2012', + 'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012', + 'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009', + 'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007', + 'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012', + 'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012', + 'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012', + 'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012', + 'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991', + 'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012', + 'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007', + 'smart fortwo Convertible 2012') + +SUN397_CATEGORIES = ( + 'abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', + 'amusement_arcade', 'amusement_park', 'anechoic_chamber', + 'apartment_building_outdoor', 'apse_indoor', 'aquarium', 'aqueduct', + 'arch', 'archive', 'arrival_gate_outdoor', 'art_gallery', 'art_school', + 'art_studio', 'assembly_line', 'athletic_field_outdoor', 'atrium_public', + 'attic', 'auditorium', 'auto_factory', 'badlands', + 'badminton_court_indoor', 'baggage_claim', 'bakery_shop', + 'balcony_exterior', 'balcony_interior', 'ball_pit', 'ballroom', + 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', + 'baseball_field', 'basement', 'basilica', 'basketball_court_outdoor', + 'bathroom', 'batters_box', 'bayou', 'bazaar_indoor', 'bazaar_outdoor', + 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', + 'bistro_indoor', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', + 'booth_indoor', 'botanical_garden', 'bow_window_indoor', + 'bow_window_outdoor', 'bowling_alley', 'boxing_ring', 'brewery_indoor', + 'bridge', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', + 'butchers_shop', 'butte', 'cabin_outdoor', 'cafeteria', 'campsite', + 'campus', 'canal_natural', 'canal_urban', 'candy_store', 'canyon', + 'car_interior_backseat', 'car_interior_frontseat', 'carrousel', + 'casino_indoor', 'castle', 'catacomb', 'cathedral_indoor', + 'cathedral_outdoor', 'cavern_indoor', 'cemetery', 'chalet', + 'cheese_factory', 'chemistry_lab', 'chicken_coop_indoor', + 'chicken_coop_outdoor', 'childs_room', 'church_indoor', 'church_outdoor', + 'classroom', 'clean_room', 'cliff', 'cloister_indoor', 'closet', + 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room', + 'conference_center', 'conference_room', 'construction_site', + 'control_room', 'control_tower_outdoor', 'corn_field', 'corral', + 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', + 'covered_bridge_exterior', 'creek', 'crevasse', 'crosswalk', + 'cubicle_office', 'dam', 'delicatessen', 'dentists_office', 'desert_sand', + 'desert_vegetation', 'diner_indoor', 'diner_outdoor', 'dinette_home', + 'dinette_vehicle', 'dining_car', 'dining_room', 'discotheque', 'dock', + 'doorway_outdoor', 'dorm_room', 'driveway', 'driving_range_outdoor', + 'drugstore', 'electrical_substation', 'elevator_door', 'elevator_interior', + 'elevator_shaft', 'engine_room', 'escalator_indoor', 'excavation', + 'factory_indoor', 'fairway', 'fastfood_restaurant', 'field_cultivated', + 'field_wild', 'fire_escape', 'fire_station', 'firing_range_indoor', + 'fishpond', 'florist_shop_indoor', 'food_court', 'forest_broadleaf', + 'forest_needleleaf', 'forest_path', 'forest_road', 'formal_garden', + 'fountain', 'galley', 'game_room', 'garage_indoor', 'garbage_dump', + 'gas_station', 'gazebo_exterior', 'general_store_indoor', + 'general_store_outdoor', 'gift_shop', 'golf_course', 'greenhouse_indoor', + 'greenhouse_outdoor', 'gymnasium_indoor', 'hangar_indoor', + 'hangar_outdoor', 'harbor', 'hayfield', 'heliport', 'herb_garden', + 'highway', 'hill', 'home_office', 'hospital', 'hospital_room', + 'hot_spring', 'hot_tub_outdoor', 'hotel_outdoor', 'hotel_room', 'house', + 'hunting_lodge_outdoor', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', + 'ice_skating_rink_indoor', 'ice_skating_rink_outdoor', 'iceberg', 'igloo', + 'industrial_area', 'inn_outdoor', 'islet', 'jacuzzi_indoor', 'jail_indoor', + 'jail_cell', 'jewelry_shop', 'kasbah', 'kennel_indoor', 'kennel_outdoor', + 'kindergarden_classroom', 'kitchen', 'kitchenette', 'labyrinth_outdoor', + 'lake_natural', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', + 'library_indoor', 'library_outdoor', 'lido_deck_outdoor', 'lift_bridge', + 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', + 'locker_room', 'mansion', 'manufactured_home', 'market_indoor', + 'market_outdoor', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', + 'moat_water', 'monastery_outdoor', 'mosque_indoor', 'mosque_outdoor', + 'motel', 'mountain', 'mountain_snowy', 'movie_theater_indoor', + 'museum_indoor', 'music_store', 'music_studio', + 'nuclear_power_plant_outdoor', 'nursery', 'oast_house', + 'observatory_outdoor', 'ocean', 'office', 'office_building', + 'oil_refinery_outdoor', 'oilrig', 'operating_room', 'orchard', + 'outhouse_outdoor', 'pagoda', 'palace', 'pantry', 'park', + 'parking_garage_indoor', 'parking_garage_outdoor', 'parking_lot', 'parlor', + 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', + 'physics_laboratory', 'picnic_area', 'pilothouse_indoor', + 'planetarium_outdoor', 'playground', 'playroom', 'plaza', 'podium_indoor', + 'podium_outdoor', 'pond', 'poolroom_establishment', 'poolroom_home', + 'power_plant_outdoor', 'promenade_deck', 'pub_indoor', 'pulpit', + 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', + 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', + 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', + 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', + 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', + 'shed', 'shoe_shop', 'shopfront', 'shopping_mall_indoor', 'shower', + 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', + 'slum', 'snowfield', 'squash_court', 'stable', 'stadium_baseball', + 'stadium_football', 'stage_indoor', 'staircase', 'street', + 'subway_interior', 'subway_station_platform', 'supermarket', 'sushi_bar', + 'swamp', 'swimming_pool_indoor', 'swimming_pool_outdoor', + 'synagogue_indoor', 'synagogue_outdoor', 'television_studio', + 'temple_east_asia', 'temple_south_asia', 'tennis_court_indoor', + 'tennis_court_outdoor', 'tent_outdoor', 'theater_indoor_procenium', + 'theater_indoor_seats', 'thriftshop', 'throne_room', 'ticket_booth', + 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'track_outdoor', + 'train_railway', 'train_station_platform', 'tree_farm', 'tree_house', + 'trench', 'underwater_coral_reef', 'utility_room', 'valley', + 'van_interior', 'vegetable_garden', 'veranda', 'veterinarians_office', + 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', + 'volleyball_court_indoor', 'volleyball_court_outdoor', 'waiting_room', + 'warehouse_indoor', 'water_tower', 'waterfall_block', 'waterfall_fan', + 'waterfall_plunge', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', + 'wind_farm', 'windmill', 'wine_cellar_barrel_storage', + 'wine_cellar_bottle_storage', 'wrestling_ring_indoor', 'yard', + 'youth_hostel') + +CALTECH101_CATEGORIES = ( + 'BACKGROUND_Google', 'Faces', 'Faces_easy', 'Leopards', 'Motorbikes', + 'accordion', 'airplanes', 'anchor', 'ant', 'barrel', 'bass', 'beaver', + 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', + 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', + 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', + 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', + 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', + 'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', + 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', + 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', + 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', + 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', + 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', + 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', + 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', + 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', + 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', + 'wrench', 'yin_yang') + +FOOD101_CATEGORIES = ( + 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', + 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', + 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', + 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', + 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', + 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', + 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', + 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', + 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', + 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', + 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', + 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', + 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', + 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', + 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', + 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', + 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', + 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', + 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', + 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', + 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', + 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles') + +CIFAR100_CATEGORIES_CN = ( + '苹果', '水族馆鱼', '婴儿', '熊', '河狸', '床', '蜜蜂', '甲虫', '自行车', '瓶子', '碗', '小男孩', + '桥', '公共汽车', '蝴蝶', '骆驼', '易拉罐', '城堡', '毛毛虫', '牛', '椅子', '猩猩', '钟', '白云', + '蟑螂', '沙发', '螃蟹', '鳄鱼', '杯子', '恐龙', '海豚', '大象', '比目鱼', '森林', '狐狸', '小女孩', + '仓鼠', '屋子', '袋鼠', '键盘', '台灯', '割草机', '猎豹', '狮子', '蜥蜴', '龙虾', '男人', '枫树', + '摩托车', '山', '老鼠', '蘑菇', '橡树', '橙子橘子', '兰花', '水獭', '棕榈树', '梨', '皮卡车', '松树', + '田野', '盘子', '罂粟', '豪猪', '负鼠', '兔子', '浣熊', '鳐鱼', '公路', '火箭', '玫瑰', '大海', + '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒', + '桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼', + '柳树', '狼', '女人', '蠕虫') diff --git a/mmpretrain/datasets/cifar.py b/mmpretrain/datasets/cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..2a011daee0d74e6b06613106f7587b8ad8a7ed90 --- /dev/null +++ b/mmpretrain/datasets/cifar.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pickle +from typing import List, Optional + +import mmengine.dist as dist +import numpy as np +from mmengine.fileio import (LocalBackend, exists, get, get_file_backend, + join_path) +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES +from .utils import check_md5, download_and_extract_archive + + +@DATASETS.register_module() +class CIFAR10(BaseDataset): + """`CIFAR10 `_ Dataset. + + This implementation is modified from + https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py + + Args: + data_root (str): The root directory of the CIFAR Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ # noqa: E501 + + base_folder = 'cifar-10-batches-py' + url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + filename = 'cifar-10-python.tar.gz' + tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ] + + test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'], + ] + meta = { + 'filename': 'batches.meta', + 'key': 'label_names', + 'md5': '5ff9c542aee3614f3951f8cda6e48888', + } + METAINFO = {'classes': CIFAR10_CATEGORIES} + + def __init__(self, + data_root: str = '', + split: str = 'train', + metainfo: Optional[dict] = None, + download: bool = True, + data_prefix: str = '', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. ' + 'The training set will be used.') + + if not data_root and not data_prefix: + raise RuntimeError('Please set ``data_root`` to' + 'specify the dataset path') + + self.download = download + super().__init__( + # The CIFAR dataset doesn't need specify annotation file + ann_file='', + metainfo=metainfo, + data_root=data_root, + data_prefix=dict(root=data_prefix), + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + root = self.data_prefix['root'] + backend = get_file_backend(root, enable_singleton=True) + + if dist.is_main_process() and not self._check_integrity(): + if not isinstance(backend, LocalBackend): + raise RuntimeError(f'The dataset on {root} is not integrated, ' + f'please manually handle it.') + + if self.download: + download_and_extract_archive( + self.url, root, filename=self.filename, md5=self.tgz_md5) + else: + raise RuntimeError( + f'Cannot find {self.__class__.__name__} dataset in ' + f"{self.data_prefix['root']}, you can specify " + '`download=True` to download automatically.') + + dist.barrier() + assert self._check_integrity(), \ + 'Download failed or shared storage is unavailable. Please ' \ + f'download the dataset manually through {self.url}.' + + if self.split == 'train': + downloaded_list = self.train_list + else: + downloaded_list = self.test_list + + imgs = [] + gt_labels = [] + + # load the picked numpy arrays + for file_name, _ in downloaded_list: + file_path = join_path(root, self.base_folder, file_name) + entry = pickle.loads(get(file_path), encoding='latin1') + imgs.append(entry['data']) + if 'labels' in entry: + gt_labels.extend(entry['labels']) + else: + gt_labels.extend(entry['fine_labels']) + + imgs = np.vstack(imgs).reshape(-1, 3, 32, 32) + imgs = imgs.transpose((0, 2, 3, 1)) # convert to HWC + + if self.CLASSES is None: + # The metainfo in the file has the lowest priority, therefore + # we only need to load it if classes is not specified. + self._load_meta() + + data_list = [] + for img, gt_label in zip(imgs, gt_labels): + info = {'img': img, 'gt_label': int(gt_label)} + data_list.append(info) + return data_list + + def _load_meta(self): + """Load categories information from metafile.""" + root = self.data_prefix['root'] + + path = join_path(root, self.base_folder, self.meta['filename']) + md5 = self.meta.get('md5', None) + if not exists(path) or (md5 is not None and not check_md5(path, md5)): + raise RuntimeError( + 'Dataset metadata file not found or corrupted.' + + ' You can use `download=True` to download it') + data = pickle.loads(get(path), encoding='latin1') + self._metainfo.setdefault('classes', data[self.meta['key']]) + + def _check_integrity(self): + """Check the integrity of data files.""" + root = self.data_prefix['root'] + + for fentry in (self.train_list + self.test_list): + filename, md5 = fentry[0], fentry[1] + fpath = join_path(root, self.base_folder, filename) + if not exists(fpath): + return False + if md5 is not None and not check_md5(fpath, md5): + return False + return True + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [f"Prefix of data: \t{self.data_prefix['root']}"] + return body + + +@DATASETS.register_module() +class CIFAR100(CIFAR10): + """`CIFAR100 `_ Dataset. + + Args: + data_root (str): The root directory of the CIFAR Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + base_folder = 'cifar-100-python' + url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' + filename = 'cifar-100-python.tar.gz' + tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' + train_list = [ + ['train', '16019d7e3df5f24257cddd939b257f8d'], + ] + + test_list = [ + ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], + ] + meta = { + 'filename': 'meta', + 'key': 'fine_label_names', + 'md5': '7973b15100ade9c7d40fb424638fde48', + } + METAINFO = {'classes': CIFAR100_CATEGORIES} diff --git a/mmpretrain/datasets/coco_caption.py b/mmpretrain/datasets/coco_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..541cda80398f7fcc7d3304d3d9f43155685ebe57 --- /dev/null +++ b/mmpretrain/datasets/coco_caption.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class COCOCaption(BaseDataset): + """COCO Caption dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``.. + ann_file (str): Annotation file path. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + file_backend = get_file_backend(img_prefix) + + data_list = [] + for ann in annotations: + data_info = { + 'image_id': Path(ann['image']).stem.split('_')[-1], + 'img_path': file_backend.join_path(img_prefix, ann['image']), + 'gt_caption': ann['caption'], + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/coco_retrieval.py b/mmpretrain/datasets/coco_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..60d1586ad8672a4b57fcdc62740b3e08c3e2e20e --- /dev/null +++ b/mmpretrain/datasets/coco_retrieval.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from collections import OrderedDict +from typing import List + +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class COCORetrieval(BaseDataset): + """COCO Retrieval dataset. + + Args: + ann_file (str): Annotation file path. + test_mode (bool): Whether dataset is used for evaluation. This will + decide the annotation format in data list annotations. + Defaults to False. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def load_data_list(self) -> List[dict]: + """Load data list.""" + # get file backend + img_prefix = self.data_prefix['img_path'] + file_backend = get_file_backend(img_prefix) + + anno_info = json.load(open(self.ann_file, 'r')) + # mapping img_id to img filename + img_dict = OrderedDict() + for idx, img in enumerate(anno_info['images']): + if img['id'] not in img_dict: + img_rel_path = img['coco_url'].rsplit('/', 2)[-2:] + img_path = file_backend.join_path(img_prefix, *img_rel_path) + + # create new idx for image + img_dict[img['id']] = dict( + ori_id=img['id'], + image_id=idx, # will be used for evaluation + img_path=img_path, + text=[], + gt_text_id=[], + gt_image_id=[], + ) + + train_list = [] + for idx, anno in enumerate(anno_info['annotations']): + anno['text'] = anno.pop('caption') + anno['ori_id'] = anno.pop('id') + anno['text_id'] = idx # will be used for evaluation + # 1. prepare train data list item + train_data = anno.copy() + train_image = img_dict[train_data['image_id']] + train_data['img_path'] = train_image['img_path'] + train_data['image_ori_id'] = train_image['ori_id'] + train_data['image_id'] = train_image['image_id'] + train_data['is_matched'] = True + train_list.append(train_data) + # 2. prepare eval data list item based on img dict + img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id']) + img_dict[anno['image_id']]['text'].append(anno['text']) + img_dict[anno['image_id']]['gt_image_id'].append( + train_image['image_id']) + + self.img_size = len(img_dict) + self.text_size = len(anno_info['annotations']) + + # return needed format data list + if self.test_mode: + return list(img_dict.values()) + return train_list diff --git a/mmpretrain/datasets/coco_vqa.py b/mmpretrain/datasets/coco_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..85f4bdcf39ef82ec47a2072dc198e6b8792d8768 --- /dev/null +++ b/mmpretrain/datasets/coco_vqa.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import re +from collections import Counter +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class COCOVQA(BaseDataset): + """VQAv2 dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + question_file (str): Question file path. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + question_file: str, + ann_file: str = '', + **kwarg): + self.question_file = question_file + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def _join_prefix(self): + if not mmengine.is_abs(self.question_file) and self.question_file: + self.question_file = osp.join(self.data_root, self.question_file) + + return super()._join_prefix() + + def _create_image_index(self): + img_prefix = self.data_prefix['img_path'] + + files = mmengine.list_dir_or_file(img_prefix, list_dir=False) + image_index = {} + for file in files: + image_id = re.findall(r'\d{12}', file) + if len(image_id) > 0: + image_id = int(image_id[-1]) + image_index[image_id] = mmengine.join_path(img_prefix, file) + + return image_index + + def load_data_list(self) -> List[dict]: + """Load data list.""" + questions = mmengine.load(self.question_file)['questions'] + if self.ann_file: + annotations = mmengine.load(self.ann_file)['annotations'] + assert len(questions) == len(annotations) + else: + annotations = [None] * len(questions) + + # The original VQAv2 annotation file and question file includes + # only image id but no image file paths. + self.image_index = self._create_image_index() + + data_list = [] + for question, ann in zip(questions, annotations): + # question example + # { + # 'image_id': 262144, + # 'question': "Is the ball flying towards the batter?", + # 'question_id': 262144000 + # } + # + # ann example + # { + # 'question_type': "what are the", + # 'answer_type': "other", + # 'answers': [ + # {'answer': 'watching', + # 'answer_id': 1, + # 'answer_confidence': 'yes'}, + # ... + # ], + # 'image_id': 262148, + # 'question_id': 262148000, + # 'multiple_choice_answer': 'watching', + # 'answer_type': 'other', + # } + + data_info = question + data_info['img_path'] = self.image_index[question['image_id']] + + if ann is not None: + assert ann['question_id'] == question['question_id'] + + # add answer_weight & answer_count, delete duplicate answer + answers = [item['answer'] for item in ann.pop('answers')] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + data_info['gt_answer'] = list(count.keys()) + data_info['gt_answer_weight'] = answer_weight + data_info.update(ann) + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/cub.py b/mmpretrain/datasets/cub.py new file mode 100644 index 0000000000000000000000000000000000000000..8db126216fb3408e2dd18255db04a851eb5fe08f --- /dev/null +++ b/mmpretrain/datasets/cub.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import CUB_CATEGORIES + + +@DATASETS.register_module() +class CUB(BaseDataset): + """The CUB-200-2011 Dataset. + + Support the `CUB-200-2011 `_ Dataset. + Comparing with the `CUB-200 `_ Dataset, + there are much more pictures in `CUB-200-2011`. After downloading and decompression, the dataset + directory structure is as follows. + + CUB dataset directory: :: + + CUB_200_2011 + ├── images + │ ├── class_x + │ │ ├── xx1.jpg + │ │ ├── xx2.jpg + │ │ └── ... + │ ├── class_y + │ │ ├── yy1.jpg + │ │ ├── yy2.jpg + │ │ └── ... + │ └── ... + ├── images.txt + ├── image_class_labels.txt + ├── train_test_split.txt + └── .... + + Args: + data_root (str): The root directory for CUB-200-2011 dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + + Examples: + >>> from mmpretrain.datasets import CUB + >>> train_dataset = CUB(data_root='data/CUB_200_2011', split='train') + >>> train_dataset + Dataset CUB + Number of samples: 5994 + Number of categories: 200 + Root of dataset: data/CUB_200_2011 + >>> test_dataset = CUB(data_root='data/CUB_200_2011', split='test') + >>> test_dataset + Dataset CUB + Number of samples: 5794 + Number of categories: 200 + Root of dataset: data/CUB_200_2011 + """ # noqa: E501 + + METAINFO = {'classes': CUB_CATEGORIES} + + def __init__(self, + data_root: str, + split: str = 'train', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. ' + 'The training set will be used.') + + ann_file = 'images.txt' + data_prefix = 'images' + image_class_labels_file = 'image_class_labels.txt' + train_test_split_file = 'train_test_split.txt' + + self.backend = get_file_backend(data_root, enable_singleton=True) + self.image_class_labels_file = self.backend.join_path( + data_root, image_class_labels_file) + self.train_test_split_file = self.backend.join_path( + data_root, train_test_split_file) + super(CUB, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def _load_data_from_txt(self, filepath): + """load data from CUB txt file, the every line of the file is idx and a + data item.""" + pairs = list_from_file(filepath) + data_dict = dict() + for pair in pairs: + idx, data_item = pair.split() + # all the index starts from 1 in CUB files, + # here we need to '- 1' to let them start from 0. + data_dict[int(idx) - 1] = data_item + return data_dict + + def load_data_list(self): + """Load images and ground truth labels.""" + sample_dict = self._load_data_from_txt(self.ann_file) + + label_dict = self._load_data_from_txt(self.image_class_labels_file) + + split_dict = self._load_data_from_txt(self.train_test_split_file) + + assert sample_dict.keys() == label_dict.keys() == split_dict.keys(),\ + f'sample_ids should be same in files {self.ann_file}, ' \ + f'{self.image_class_labels_file} and {self.train_test_split_file}' + + data_list = [] + for sample_id in sample_dict.keys(): + if split_dict[sample_id] == '1' and self.split == 'test': + # skip train samples when split='test' + continue + elif split_dict[sample_id] == '0' and self.split == 'train': + # skip test samples when split='train' + continue + + img_path = self.backend.join_path(self.img_prefix, + sample_dict[sample_id]) + gt_label = int(label_dict[sample_id]) - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/custom.py b/mmpretrain/datasets/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..bb491ff0cc7f816f629603d3b8be55e3f787c373 --- /dev/null +++ b/mmpretrain/datasets/custom.py @@ -0,0 +1,287 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +from mmengine.fileio import (BaseStorageBackend, get_file_backend, + list_from_file) +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +def find_folders( + root: str, + backend: Optional[BaseStorageBackend] = None +) -> Tuple[List[str], Dict[str, int]]: + """Find classes by folders under a root. + + Args: + root (string): root directory of folders + backend (BaseStorageBackend | None): The file backend of the root. + If None, auto infer backend from the root path. Defaults to None. + + Returns: + Tuple[List[str], Dict[str, int]]: + + - folders: The name of sub folders under the root. + - folder_to_idx: The map from folder name to class idx. + """ + # Pre-build file backend to prevent verbose file backend inference. + backend = backend or get_file_backend(root, enable_singleton=True) + folders = list( + backend.list_dir_or_file( + root, + list_dir=True, + list_file=False, + recursive=False, + )) + folders.sort() + folder_to_idx = {folders[i]: i for i in range(len(folders))} + return folders, folder_to_idx + + +def get_samples( + root: str, + folder_to_idx: Dict[str, int], + is_valid_file: Callable, + backend: Optional[BaseStorageBackend] = None, +): + """Make dataset by walking all images under a root. + + Args: + root (string): root directory of folders + folder_to_idx (dict): the map from class name to class idx + is_valid_file (Callable): A function that takes path of a file + and check if the file is a valid sample file. + backend (BaseStorageBackend | None): The file backend of the root. + If None, auto infer backend from the root path. Defaults to None. + + Returns: + Tuple[list, set]: + + - samples: a list of tuple where each element is (image, class_idx) + - empty_folders: The folders don't have any valid files. + """ + samples = [] + available_classes = set() + # Pre-build file backend to prevent verbose file backend inference. + backend = backend or get_file_backend(root, enable_singleton=True) + + if folder_to_idx is not None: + for folder_name in sorted(list(folder_to_idx.keys())): + _dir = backend.join_path(root, folder_name) + files = backend.list_dir_or_file( + _dir, + list_dir=False, + list_file=True, + recursive=True, + ) + for file in sorted(list(files)): + if is_valid_file(file): + path = backend.join_path(folder_name, file) + item = (path, folder_to_idx[folder_name]) + samples.append(item) + available_classes.add(folder_name) + empty_folders = set(folder_to_idx.keys()) - available_classes + else: + files = backend.list_dir_or_file( + root, + list_dir=False, + list_file=True, + recursive=True, + ) + samples = [file for file in sorted(list(files)) if is_valid_file(file)] + empty_folders = None + + return samples, empty_folders + + +@DATASETS.register_module() +class CustomDataset(BaseDataset): + """A generic dataset for multiple tasks. + + The dataset supports two kinds of style. + + 1. Use an annotation file to specify all samples, and each line indicates a + sample: + + The annotation file (for ``with_label=True``, supervised tasks.): :: + + folder_1/xxx.png 0 + folder_1/xxy.png 1 + 123.png 4 + nsdf3.png 3 + ... + + The annotation file (for ``with_label=False``, unsupervised tasks.): :: + + folder_1/xxx.png + folder_1/xxy.png + 123.png + nsdf3.png + ... + + Sample files: :: + + data_prefix/ + ├── folder_1 + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + ├── 123.png + ├── nsdf3.png + └── ... + + Please use the argument ``metainfo`` to specify extra information for + the task, like ``{'classes': ('bird', 'cat', 'deer', 'dog', 'frog')}``. + + 2. Place all samples in one folder as below: + + Sample files (for ``with_label=True``, supervised tasks, we use the name + of sub-folders as the categories names): :: + + data_prefix/ + ├── class_x + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + │ └── xxz.png + └── class_y + ├── 123.png + ├── nsdf3.png + ├── ... + └── asd932_.png + + Sample files (for ``with_label=False``, unsupervised tasks, we use all + sample files under the specified folder): :: + + data_prefix/ + ├── folder_1 + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + ├── 123.png + ├── nsdf3.png + └── ... + + If the ``ann_file`` is specified, the dataset will be generated by the + first way, otherwise, try the second way. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for the data. Defaults to ''. + ann_file (str): Annotation file path. Defaults to ''. + with_label (bool): Whether the annotation file includes ground truth + labels, or use sub-folders to specify categories. + Defaults to True. + extensions (Sequence[str]): A sequence of allowed extensions. Defaults + to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'). + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + lazy_init (bool): Whether to load annotation during instantiation. + In some cases, such as visualization, only the meta information of + the dataset is needed, which is not necessary to load annotation + file. ``Basedataset`` can skip load annotations to save time by set + ``lazy_init=False``. Defaults to False. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + with_label=True, + extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', + '.bmp', '.pgm', '.tif'), + metainfo: Optional[dict] = None, + lazy_init: bool = False, + **kwargs): + assert (ann_file or data_prefix or data_root), \ + 'One of `ann_file`, `data_root` and `data_prefix` must '\ + 'be specified.' + + self.extensions = tuple(set([i.lower() for i in extensions])) + self.with_label = with_label + + super().__init__( + # The base class requires string ann_file but this class doesn't + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + # Force to lazy_init for some modification before loading data. + lazy_init=True, + **kwargs) + + # Full initialize the dataset. + if not lazy_init: + self.full_init() + + def _find_samples(self): + """find samples from ``data_prefix``.""" + if self.with_label: + classes, folder_to_idx = find_folders(self.img_prefix) + samples, empty_classes = get_samples( + self.img_prefix, + folder_to_idx, + is_valid_file=self.is_valid_file, + ) + + self.folder_to_idx = folder_to_idx + + if self.CLASSES is not None: + assert len(self.CLASSES) == len(classes), \ + f"The number of subfolders ({len(classes)}) doesn't " \ + f'match the number of specified classes ' \ + f'({len(self.CLASSES)}). Please check the data folder.' + else: + self._metainfo['classes'] = tuple(classes) + else: + samples, empty_classes = get_samples( + self.img_prefix, + None, + is_valid_file=self.is_valid_file, + ) + + if len(samples) == 0: + raise RuntimeError( + f'Found 0 files in subfolders of: {self.data_prefix}. ' + f'Supported extensions are: {",".join(self.extensions)}') + + if empty_classes: + logger = MMLogger.get_current_instance() + logger.warning( + 'Found no valid file in the folder ' + f'{", ".join(empty_classes)}. ' + f"Supported extensions are: {', '.join(self.extensions)}") + + return samples + + def load_data_list(self): + """Load image paths and gt_labels.""" + if not self.ann_file: + samples = self._find_samples() + elif self.with_label: + lines = list_from_file(self.ann_file) + samples = [x.strip().rsplit(' ', 1) for x in lines] + else: + samples = list_from_file(self.ann_file) + + # Pre-build file backend to prevent verbose file backend inference. + backend = get_file_backend(self.img_prefix, enable_singleton=True) + data_list = [] + for sample in samples: + if self.with_label: + filename, gt_label = sample + img_path = backend.join_path(self.img_prefix, filename) + info = {'img_path': img_path, 'gt_label': int(gt_label)} + else: + img_path = backend.join_path(self.img_prefix, sample) + info = {'img_path': img_path} + data_list.append(info) + return data_list + + def is_valid_file(self, filename: str) -> bool: + """Check if a file is a valid sample.""" + return filename.lower().endswith(self.extensions) diff --git a/mmpretrain/datasets/dataset_wrappers.py b/mmpretrain/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..1adff10beb024940f9066a407cc76ddb06b27404 --- /dev/null +++ b/mmpretrain/datasets/dataset_wrappers.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import numpy as np +from mmengine.dataset import BaseDataset, force_full_init + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class KFoldDataset: + """A wrapper of dataset for K-Fold cross-validation. + + K-Fold cross-validation divides all the samples in groups of samples, + called folds, of almost equal sizes. And we use k-1 of folds to do training + and use the fold left to do validation. + + Args: + dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be + divided + fold (int): The fold used to do validation. Defaults to 0. + num_splits (int): The number of all folds. Defaults to 5. + test_mode (bool): Use the training dataset or validation dataset. + Defaults to False. + seed (int, optional): The seed to shuffle the dataset before splitting. + If None, not shuffle the dataset. Defaults to None. + """ + + def __init__(self, + dataset, + fold=0, + num_splits=5, + test_mode=False, + seed=None): + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + # Init the dataset wrapper lazily according to the dataset setting. + lazy_init = dataset.get('lazy_init', False) + elif isinstance(dataset, BaseDataset): + self.dataset = dataset + else: + raise TypeError(f'Unsupported dataset type {type(dataset)}.') + + self._metainfo = getattr(self.dataset, 'metainfo', {}) + self.fold = fold + self.num_splits = num_splits + self.test_mode = test_mode + self.seed = seed + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def metainfo(self) -> dict: + """Get the meta information of ``self.dataset``. + + Returns: + dict: Meta information of the dataset. + """ + # Prevent `self._metainfo` from being modified by outside. + return copy.deepcopy(self._metainfo) + + def full_init(self): + """fully initialize the dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + ori_len = len(self.dataset) + indices = list(range(ori_len)) + if self.seed is not None: + rng = np.random.default_rng(self.seed) + rng.shuffle(indices) + + test_start = ori_len * self.fold // self.num_splits + test_end = ori_len * (self.fold + 1) // self.num_splits + if self.test_mode: + indices = indices[test_start:test_end] + else: + indices = indices[:test_start] + indices[test_end:] + + self._ori_indices = indices + self.dataset = self.dataset.get_subset(indices) + + self._fully_initialized = True + + @force_full_init + def _get_ori_dataset_idx(self, idx: int) -> int: + """Convert global idx to local index. + + Args: + idx (int): Global index of ``KFoldDataset``. + + Returns: + int: The original index in the whole dataset. + """ + return self._ori_indices[idx] + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``KFoldDataset``. + + Returns: + dict: The idx-th annotation of the datasets. + """ + return self.dataset.get_data_info(idx) + + @force_full_init + def __len__(self): + return len(self.dataset) + + @force_full_init + def __getitem__(self, idx): + return self.dataset[idx] + + @force_full_init + def get_cat_ids(self, idx): + return self.dataset.get_cat_ids(idx) + + @force_full_init + def get_gt_labels(self): + return self.dataset.get_gt_labels() + + @property + def CLASSES(self): + """Return all categories names.""" + return self._metainfo.get('classes', None) + + @property + def class_to_idx(self): + """Map mapping class name to class index. + + Returns: + dict: mapping from class name to class index. + """ + + return {cat: i for i, cat in enumerate(self.CLASSES)} + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. + """ + head = 'Dataset ' + self.__class__.__name__ + body = [] + type_ = 'test' if self.test_mode else 'training' + body.append(f'Type: \t{type_}') + body.append(f'Seed: \t{self.seed}') + + def ordinal(n): + # Copy from https://codegolf.stackexchange.com/a/74047 + suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4] + return f'{n}{suffix}' + + body.append( + f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold') + if self._fully_initialized: + body.append(f'Number of samples: \t{self.__len__()}') + else: + body.append("Haven't been initialized") + + if self.CLASSES is not None: + body.append(f'Number of categories: \t{len(self.CLASSES)}') + else: + body.append('The `CLASSES` meta info is not set.') + + body.append( + f'Original dataset type:\t{self.dataset.__class__.__name__}') + + lines = [head] + [' ' * 4 + line for line in body] + return '\n'.join(lines) diff --git a/mmpretrain/datasets/dtd.py b/mmpretrain/datasets/dtd.py new file mode 100644 index 0000000000000000000000000000000000000000..034d0b1b444afebfc420eeff7e138072f7d7ee1f --- /dev/null +++ b/mmpretrain/datasets/dtd.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import DTD_CATEGORIES + + +@DATASETS.register_module() +class DTD(BaseDataset): + """The Describable Texture Dataset (DTD). + + Support the `Describable Texture Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + DTD dataset directory: :: + + dtd + ├── images + │ ├── banded + | | ├──banded_0002.jpg + | | ├──banded_0004.jpg + | | └── ... + │ └── ... + ├── imdb + │ └── imdb.mat + ├── labels + | | ├──labels_joint_anno.txt + | | ├──test1.txt + | | ├──test2.txt + | | └── ... + │ └── ... + └── .... + + Args: + data_root (str): The root directory for Describable Texture dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". + + Examples: + >>> from mmpretrain.datasets import DTD + >>> train_dataset = DTD(data_root='data/dtd', split='trainval') + >>> train_dataset + Dataset DTD + Number of samples: 3760 + Number of categories: 47 + Root of dataset: data/dtd + >>> test_dataset = DTD(data_root='data/dtd', split='test') + >>> test_dataset + Dataset DTD + Number of samples: 1880 + Number of categories: 47 + Root of dataset: data/dtd + """ # noqa: E501 + + METAINFO = {'classes': DTD_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + data_prefix = 'images' + test_mode = split == 'test' + + self.backend = get_file_backend(data_root, enable_singleton=True) + ann_file = self.backend.join_path('imdb', 'imdb.mat') + + super(DTD, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + data = mat4py.loadmat(self.ann_file)['images'] + names = data['name'] + labels = data['class'] + parts = data['set'] + num = len(names) + assert num == len(labels) == len(parts), 'get error ann file' + + if self.split == 'train': + target_set = {1} + elif self.split == 'val': + target_set = {2} + elif self.split == 'test': + target_set = {3} + else: + target_set = {1, 2} + + data_list = [] + for i in range(num): + if parts[i] in target_set: + img_name = names[i] + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = labels[i] - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/fgvcaircraft.py b/mmpretrain/datasets/fgvcaircraft.py new file mode 100644 index 0000000000000000000000000000000000000000..696992c06bbf02f097d017a519d42f758ba5f16f --- /dev/null +++ b/mmpretrain/datasets/fgvcaircraft.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FGVCAIRCRAFT_CATEGORIES + + +@DATASETS.register_module() +class FGVCAircraft(BaseDataset): + """The FGVC_Aircraft Dataset. + + Support the `FGVC_Aircraft Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + FGVC_Aircraft dataset directory: :: + + fgvc-aircraft-2013b + └── data + ├── images + │ ├── 1.jpg + │ ├── 2.jpg + │ └── ... + ├── images_variant_train.txt + ├── images_variant_test.txt + ├── images_variant_trainval.txt + ├── images_variant_val.txt + ├── variants.txt + └── .... + + Args: + data_root (str): The root directory for FGVC_Aircraft dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". + + Examples: + >>> from mmpretrain.datasets import FGVCAircraft + >>> train_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='trainval') + >>> train_dataset + Dataset FGVCAircraft + Number of samples: 6667 + Number of categories: 100 + Root of dataset: data/fgvc-aircraft-2013b + >>> test_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='test') + >>> test_dataset + Dataset FGVCAircraft + Number of samples: 3333 + Number of categories: 100 + Root of dataset: data/fgvc-aircraft-2013b + """ # noqa: E501 + + METAINFO = {'classes': FGVCAIRCRAFT_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + ann_file = self.backend.join_path('data', + f'images_variant_{split}.txt') + data_prefix = self.backend.join_path('data', 'images') + test_mode = split == 'test' + + super(FGVCAircraft, self).__init__( + ann_file=ann_file, + data_root=data_root, + test_mode=test_mode, + data_prefix=data_prefix, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + for pair in pairs: + pair = pair.split() + img_name = pair[0] + class_name = ' '.join(pair[1:]) + img_name = f'{img_name}.jpg' + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = self.METAINFO['classes'].index(class_name) + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/flamingo.py b/mmpretrain/datasets/flamingo.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5745a1437537fccbc304d158a0f0c8d09f032a --- /dev/null +++ b/mmpretrain/datasets/flamingo.py @@ -0,0 +1,295 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from abc import abstractmethod +from collections import Counter +from typing import List + +import mmengine +import numpy as np +from mmengine.dataset import BaseDataset +from pycocotools.coco import COCO + +from mmpretrain.registry import DATASETS +from .coco_vqa import COCOVQA + + +class FlamingoFewShotMixin: + """Flamingo fewshot eval dataset minin. + + Args: + num_shots (int): Number of shots to perform evaluation. + Defaults to 0. + Note: 0 does not mean a strict zero-shot in Flamingo setting. + It will use 2 only-text prompt without in context images. + num_support_examples (int): Number of support examples to get the + few shots from. Defaults to 2048. + num_query_examples (int): Number of query examples to perform the + final evaluation. Defaults to 5000. + incontext_prompt_temp (str): In context prompt template for few shot + examples. Defaults to ''. + final_prompt_temp (str): Final query prompt template. Defaults to ''. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + incontext_prompt_temp: str = '', + final_prompt_temp: str = '', + **kwarg): + self.num_shots = num_shots + self.num_support_examples = num_support_examples + self.num_query_examples = num_query_examples + self.incontext_prompt_temp = incontext_prompt_temp + self.final_prompt_temp = final_prompt_temp + super().__init__(**kwarg) + + def get_subset_idx(self, total_num): + random_idx = np.random.choice( + total_num, + self.num_support_examples + self.num_query_examples, + replace=False) + + support_idx = random_idx[:self.num_support_examples] + query_idx = random_idx[self.num_support_examples:] + return support_idx, query_idx + + @abstractmethod + def parse_basic_anno(self, anno: dict) -> dict: + """Parse basic annotation for support and query set.""" + pass + + @abstractmethod + def parse_fewshot_anno(self, anno: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list.""" + pass + + +@DATASETS.register_module() +class FlamingoEvalCOCOVQA(FlamingoFewShotMixin, COCOVQA): + """Flamingo few shot VQAv2 dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. + ann_file (str): Annotation file path. + question_file (str): Question file path. + num_shots (int): Number of shots to perform evaluation. + Defaults to 0. + Note: 0 does not mean a strict zero-shot in Flamingo setting. + It will use 2 only-text prompt without in context images. + num_support_examples (int): Number of support examples to get the + few shots from. Defaults to 2048. + num_query_examples (int): Number of query examples to perform the + final evaluation. Defaults to 5000. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + question_file: str, + ann_file: str = '', + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + **kwarg): + super().__init__( + data_root=data_root, + question_file=question_file, + ann_file=ann_file, + num_shots=num_shots, + num_support_examples=num_support_examples, + num_query_examples=num_query_examples, + **kwarg) + + def parse_basic_anno(self, ann: dict) -> dict: + """Parse basic annotation for support and query set. + + Args: + anno (dict): Annotation for single example. + + Return: + dict: Parsed annotation for single example. + """ + if ann is None: + return {} + + answers = [a['answer'] for a in ann['answers']] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + answer_info = { + 'gt_answer': list(count.keys()), + 'gt_answer_weight': answer_weight + } + return answer_info + + def parse_fewshot_anno(self, query: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list. + + Args: + anno (dict): Annotation for single example. + support_list (List): List of support subset to subsample few shots. + + Return: + dict: Parsed annotation for single example. + """ + # prepare n shots examples + shots = random.sample(support_list, self.num_shots) + + # append image path for n shots + img_path = [shot['img_path'] for shot in shots] + img_path.append(query['img_path']) + query['img_path'] = img_path + + query['shots'] = [ + dict( + question=item['question'], + answer=item['gt_answer'][0], + ) for item in shots + ] + return query + + def load_data_list(self) -> List[dict]: + """Load data list.""" + questions = mmengine.load(self.question_file)['questions'] + if self.ann_file: + annotations = mmengine.load(self.ann_file)['annotations'] + assert len(questions) == len(annotations) + else: + annotations = [None] * len(questions) + if self.num_shots > 0: + raise ValueError('Unable to construct few-shot examples ' + 'since no annotation file.') + + # The original VQAv2 annotation file and question file includes + # only image id but no image file paths. + self.image_index = self._create_image_index() + + num_data = len(questions) + support_idx, query_idx = self.get_subset_idx(num_data) + + # prepare support subset + if self.num_shots > 0: + support_list = [] + for idx in support_idx: + question = questions[idx] + ann = annotations[idx] + support = {**question, **self.parse_basic_anno(ann)} + support['img_path'] = self.image_index[question['image_id']] + support_list.append(support) + + # prepare query subset + data_list = [] + for idx in query_idx: + question = questions[idx] + ann = annotations[idx] + data_info = {**question, **self.parse_basic_anno(ann)} + data_info['img_path'] = self.image_index[question['image_id']] + if self.num_shots > 0: + data_info = self.parse_fewshot_anno(data_info, support_list) + data_list.append(data_info) + + return data_list + + +@DATASETS.register_module() +class FlamingoEvalCOCOCaption(FlamingoFewShotMixin, BaseDataset): + """Flamingo few shot COCO Caption dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. + ann_file (str): Annotation file path. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + num_shots (int): Number of shots to perform evaluation. + Defaults to 0. + num_support_examples (int): Number of support examples to get the + few shots from. Defaults to 2048. + num_query_examples (int): Number of query examples to perform the + final evaluation. Defaults to 5000. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + ann_file: str, + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + **kwarg): + super().__init__( + data_root=data_root, + ann_file=ann_file, + num_shots=num_shots, + num_support_examples=num_support_examples, + num_query_examples=num_query_examples, + **kwarg) + + def parse_basic_anno(self, ann: dict, coco: COCO) -> dict: + """Parse basic annotation for support and query set. + + Args: + anno (dict): Annotation for single example. + coco (COCO): The coco dataset. + + Return: + dict: Parsed annotation for single example. + """ + img_prefix = self.data_prefix['img_path'] + img = coco.imgs[ann['image_id']] + data_info = dict( + img_path=mmengine.join_path(img_prefix, img['file_name']), + gt_caption=ann['caption'], + image_id=ann['image_id'], + ) + return data_info + + def parse_fewshot_anno(self, query: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list. + + Args: + query (dict): Annotation for single example. + support_list (List): List of support subset to subsample few shots. + coco (COCO): The coco dataset. + + Return: + dict: Parsed annotation for single example. + """ + # prepare n shots examples + shots = random.sample(support_list, self.num_shots) + + # append image path for n shots + img_path = [shot['img_path'] for shot in shots] + img_path.append(query['img_path']) + query['img_path'] = img_path + + query['shots'] = [dict(caption=item['gt_caption']) for item in shots] + return query + + def load_data_list(self) -> List[dict]: + """Load data list.""" + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + + num_data = len(coco.anns) + support_idx, query_idx = self.get_subset_idx(num_data) + ann_ids = list(coco.anns) + + # prepare support subset + if self.num_shots > 0: + support_list = [] + for idx in support_idx: + support = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco) + support_list.append(support) + + # prepare query subset + query_list = [] + for idx in query_idx: + data_info = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco) + if self.num_shots > 0: + data_info = self.parse_fewshot_anno(data_info, support_list) + query_list.append(data_info) + + return query_list diff --git a/mmpretrain/datasets/flickr30k_caption.py b/mmpretrain/datasets/flickr30k_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f6841a2c87a0b3eaa3a7abd5b8fda1cb235bc0 --- /dev/null +++ b/mmpretrain/datasets/flickr30k_caption.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class Flickr30kCaption(BaseDataset): + """Flickr30k Caption dataset. To generate coco-style GT annotation for + evaluation, please refer to + tools/dataset_converters/convert_flickr30k_ann.py. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + file_backend = get_file_backend(img_prefix) + + data_list = [] + + for img in annotations['images']: + + # img_example={ + # "sentids": [0, 1, 2], + # "imgid": 0, + # "sentences": [ + # {"raw": "Two men in green shirts standing in a yard.", + # "imgid": 0, "sentid": 0}, + # {"raw": "A man in a blue shirt standing in a garden.", + # "imgid": 0, "sentid": 1}, + # {"raw": "Two friends enjoy time spent together.", + # "imgid": 0, "sentid": 2} + # ], + # "split": "train", + # "filename": "1000092795.jpg" + # }, + + if img['split'] != self.split: + continue + + for sentence in img['sentences']: + data_info = { + 'image_id': img['imgid'], + 'img_path': file_backend.join_path(img_prefix, + img['filename']), + 'gt_caption': sentence['raw'] + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/flickr30k_retrieval.py b/mmpretrain/datasets/flickr30k_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..9f43c151b2079b3f72cf620577923efc57987316 --- /dev/null +++ b/mmpretrain/datasets/flickr30k_retrieval.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import List + +import mmengine +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class Flickr30kRetrieval(BaseDataset): + """Flickr30k Retrieval dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + # get file backend + img_prefix = self.data_prefix['img_path'] + file_backend = get_file_backend(img_prefix) + + annotations = mmengine.load(self.ann_file) + + # mapping img_id to img filename + img_dict = OrderedDict() + img_idx = 0 + sentence_idx = 0 + train_list = [] + for img in annotations['images']: + + # img_example={ + # "sentids": [0, 1, 2], + # "imgid": 0, + # "sentences": [ + # {"raw": "Two men in green shirts standing in a yard.", + # "imgid": 0, "sentid": 0}, + # {"raw": "A man in a blue shirt standing in a garden.", + # "imgid": 0, "sentid": 1}, + # {"raw": "Two friends enjoy time spent together.", + # "imgid": 0, "sentid": 2} + # ], + # "split": "train", + # "filename": "1000092795.jpg" + # }, + + if img['split'] != self.split: + continue + + # create new idx for image + train_image = dict( + ori_id=img['imgid'], + image_id=img_idx, # used for evaluation + img_path=file_backend.join_path(img_prefix, img['filename']), + text=[], + gt_text_id=[], + gt_image_id=[], + ) + + for sentence in img['sentences']: + ann = {} + ann['text'] = sentence['raw'] + ann['ori_id'] = sentence['sentid'] + ann['text_id'] = sentence_idx # used for evaluation + + ann['image_ori_id'] = train_image['ori_id'] + ann['image_id'] = train_image['image_id'] + ann['img_path'] = train_image['img_path'] + ann['is_matched'] = True + + # 1. prepare train data list item + train_list.append(ann) + # 2. prepare eval data list item based on img dict + train_image['text'].append(ann['text']) + train_image['gt_text_id'].append(ann['text_id']) + train_image['gt_image_id'].append(ann['image_id']) + + sentence_idx += 1 + + img_dict[img['imgid']] = train_image + img_idx += 1 + + self.img_size = len(img_dict) + self.text_size = len(train_list) + + # return needed format data list + if self.test_mode: + return list(img_dict.values()) + return train_list diff --git a/mmpretrain/datasets/flowers102.py b/mmpretrain/datasets/flowers102.py new file mode 100644 index 0000000000000000000000000000000000000000..fe76dcc8422c8692261800b204a6262b60002e81 --- /dev/null +++ b/mmpretrain/datasets/flowers102.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class Flowers102(BaseDataset): + """The Oxford 102 Flower Dataset. + + Support the `Oxford 102 Flowers Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Flowers102 dataset directory: :: + + Flowers102 + ├── jpg + │ ├── image_00001.jpg + │ ├── image_00002.jpg + │ └── ... + ├── imagelabels.mat + ├── setid.mat + └── ... + + Args: + data_root (str): The root directory for Oxford 102 Flowers dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". + + Examples: + >>> from mmpretrain.datasets import Flowers102 + >>> train_dataset = Flowers102(data_root='data/Flowers102', split='trainval') + >>> train_dataset + Dataset Flowers102 + Number of samples: 2040 + Root of dataset: data/Flowers102 + >>> test_dataset = Flowers102(data_root='data/Flowers102', split='test') + >>> test_dataset + Dataset Flowers102 + Number of samples: 6149 + Root of dataset: data/Flowers102 + """ # noqa: E501 + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + ann_file = 'imagelabels.mat' + data_prefix = 'jpg' + train_test_split_file = 'setid.mat' + test_mode = split == 'test' + + self.backend = get_file_backend(data_root, enable_singleton=True) + + self.train_test_split_file = self.backend.join_path( + data_root, train_test_split_file) + + super(Flowers102, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + label_dict = mat4py.loadmat(self.ann_file)['labels'] + split_list = mat4py.loadmat(self.train_test_split_file) + + if self.split == 'train': + split_list = split_list['trnid'] + elif self.split == 'val': + split_list = split_list['valid'] + elif self.split == 'test': + split_list = split_list['tstid'] + else: + train_ids = split_list['trnid'] + val_ids = split_list['valid'] + train_ids.extend(val_ids) + split_list = train_ids + + data_list = [] + for sample_id in split_list: + img_name = 'image_%05d.jpg' % (sample_id) + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = int(label_dict[sample_id - 1]) - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/food101.py b/mmpretrain/datasets/food101.py new file mode 100644 index 0000000000000000000000000000000000000000..4ce7ffeee91c6843c259149770e9de4ad9f4317a --- /dev/null +++ b/mmpretrain/datasets/food101.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FOOD101_CATEGORIES + + +@DATASETS.register_module() +class Food101(BaseDataset): + """The Food101 Dataset. + + Support the `Food101 Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Food101 dataset directory: :: + + food-101 + ├── images + │ ├── class_x + │ │ ├── xx1.jpg + │ │ ├── xx2.jpg + │ │ └── ... + │ ├── class_y + │ │ ├── yy1.jpg + │ │ ├── yy2.jpg + │ │ └── ... + │ └── ... + ├── meta + │ ├── train.txt + │ └── test.txt + └── .... + + Args: + data_root (str): The root directory for Food101 dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + + Examples: + >>> from mmpretrain.datasets import Food101 + >>> train_dataset = Food101(data_root='data/food-101', split='train') + >>> train_dataset + Dataset Food101 + Number of samples: 75750 + Number of categories: 101 + Root of dataset: data/food-101 + >>> test_dataset = Food101(data_root='data/food-101', split='test') + >>> test_dataset + Dataset Food101 + Number of samples: 25250 + Number of categories: 101 + Root of dataset: data/food-101 + """ # noqa: E501 + + METAINFO = {'classes': FOOD101_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'train', **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + if split == 'train': + ann_file = self.backend.join_path('meta', 'train.txt') + else: + ann_file = self.backend.join_path('meta', 'test.txt') + + test_mode = split == 'test' + data_prefix = 'images' + + super(Food101, self).__init__( + ann_file=ann_file, + data_root=data_root, + test_mode=test_mode, + data_prefix=data_prefix, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + for pair in pairs: + class_name, img_name = pair.split('/') + img_name = f'{img_name}.jpg' + img_path = self.backend.join_path(self.img_prefix, class_name, + img_name) + gt_label = self.METAINFO['classes'].index(class_name) + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/gqa_dataset.py b/mmpretrain/datasets/gqa_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..741791bc2bb51f768e8907aac7f002f0e730aeea --- /dev/null +++ b/mmpretrain/datasets/gqa_dataset.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class GQA(BaseDataset): + """GQA dataset. + + We use the annotation file from LAVIS, and you can download all annotation files from following links: # noqa: E501 + + train: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json # noqa: E501 + val: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json # noqa: E501 + test: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json # noqa: E501 + + and images from the official website: + https://cs.stanford.edu/people/dorarad/gqa/index.html + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + data_list = [] + for ann in annotations: + # ann example + # { + # 'question': "Is it overcast?", + # 'answer': 'no, + # 'image_id': n161313.jpg, + # 'question_id': 262148000, + # .... + # } + data_info = dict() + data_info['img_path'] = osp.join(self.data_prefix['img_path'], + ann['image']) + data_info['question'] = ann['question'] + data_info['gt_answer'] = ann['answer'] + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/imagenet.py b/mmpretrain/datasets/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..771d6ee454e3dc094962ca09036888f97ffb2d21 --- /dev/null +++ b/mmpretrain/datasets/imagenet.py @@ -0,0 +1,235 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +from mmengine import fileio +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .categories import IMAGENET_CATEGORIES +from .custom import CustomDataset + + +@DATASETS.register_module() +class ImageNet(CustomDataset): + """`ImageNet `_ Dataset. + + The dataset supports two kinds of directory format, + + :: + + imagenet + ├── train + │ ├──class_x + | | ├── x1.jpg + | | ├── x2.jpg + | | └── ... + │ ├── class_y + | | ├── y1.jpg + | | ├── y2.jpg + | | └── ... + | └── ... + ├── val + │ ├──class_x + | | └── ... + │ ├── class_y + | | └── ... + | └── ... + └── test + ├── test1.jpg + ├── test2.jpg + └── ... + + or :: + + imagenet + ├── train + │ ├── x1.jpg + │ ├── y1.jpg + │ └── ... + ├── val + │ ├── x3.jpg + │ ├── y3.jpg + │ └── ... + ├── test + │ ├── test1.jpg + │ ├── test2.jpg + │ └── ... + └── meta + ├── train.txt + └── val.txt + + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + split (str): The dataset split, supports "train", "val" and "test". + Default to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + + + Examples: + >>> from mmpretrain.datasets import ImageNet + >>> train_dataset = ImageNet(data_root='data/imagenet', split='train') + >>> train_dataset + Dataset ImageNet + Number of samples: 1281167 + Number of categories: 1000 + Root of dataset: data/imagenet + >>> test_dataset = ImageNet(data_root='data/imagenet', split='val') + >>> test_dataset + Dataset ImageNet + Number of samples: 50000 + Number of categories: 1000 + Root of dataset: data/imagenet + """ # noqa: E501 + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + METAINFO = {'classes': IMAGENET_CATEGORIES} + + def __init__(self, + data_root: str = '', + split: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + **kwargs): + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + + if split: + splits = ['train', 'val', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + + if split == 'test': + logger = MMLogger.get_current_instance() + logger.info( + 'Since the ImageNet1k test set does not provide label' + 'annotations, `with_label` is set to False') + kwargs['with_label'] = False + + data_prefix = split if data_prefix == '' else data_prefix + + if ann_file == '': + _ann_path = fileio.join_path(data_root, 'meta', f'{split}.txt') + if fileio.exists(_ann_path): + ann_file = fileio.join_path('meta', f'{split}.txt') + + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body + + +@DATASETS.register_module() +class ImageNet21k(CustomDataset): + """ImageNet21k Dataset. + + Since the dataset ImageNet21k is extremely big, contains 21k+ classes + and 1.4B files. We won't provide the default categories list. Please + specify it from the ``classes`` argument. + The dataset directory structure is as follows, + + ImageNet21k dataset directory :: + + imagenet21k + ├── train + │ ├──class_x + | | ├── x1.jpg + | | ├── x2.jpg + | | └── ... + │ ├── class_y + | | ├── y1.jpg + | | ├── y2.jpg + | | └── ... + | └── ... + └── meta + └── train.txt + + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + multi_label (bool): Not implement by now. Use multi label or not. + Defaults to False. + **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import ImageNet21k + >>> train_dataset = ImageNet21k(data_root='data/imagenet21k', split='train') + >>> train_dataset + Dataset ImageNet21k + Number of samples: 14197088 + Annotation file: data/imagenet21k/meta/train.txt + Prefix of images: data/imagenet21k/train + """ # noqa: E501 + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + + def __init__(self, + data_root: str = '', + split: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + multi_label: bool = False, + **kwargs): + if multi_label: + raise NotImplementedError( + 'The `multi_label` option is not supported by now.') + self.multi_label = multi_label + + if split: + splits = ['train'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'.\ + If you want to specify your own validation set or test set,\ + please set split to None." + + self.split = split + data_prefix = split if data_prefix == '' else data_prefix + + if not ann_file: + _ann_path = fileio.join_path(data_root, 'meta', f'{split}.txt') + if fileio.exists(_ann_path): + ann_file = fileio.join_path('meta', f'{split}.txt') + + logger = MMLogger.get_current_instance() + + if not ann_file: + logger.warning( + 'The ImageNet21k dataset is large, and scanning directory may ' + 'consume long time. Considering to specify the `ann_file` to ' + 'accelerate the initialization.') + + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) + + if self.CLASSES is None: + logger.warning( + 'The CLASSES is not stored in the `ImageNet21k` class. ' + 'Considering to specify the `classes` argument if you need ' + 'do inference on the ImageNet-21k dataset') diff --git a/mmpretrain/datasets/inshop.py b/mmpretrain/datasets/inshop.py new file mode 100644 index 0000000000000000000000000000000000000000..f64f1779632d4a98d0e36d59750f4a1e8cbd4aed --- /dev/null +++ b/mmpretrain/datasets/inshop.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class InShop(BaseDataset): + """InShop Dataset for Image Retrieval. + + Please download the images from the homepage + 'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html' + (In-shop Clothes Retrieval Benchmark -> Img -> img.zip, + Eval/list_eval_partition.txt), and organize them as follows way: :: + + In-shop Clothes Retrieval Benchmark (data_root)/ + ├── Eval / + │ └── list_eval_partition.txt (ann_file) + ├── Img (img_prefix) + │ └── img/ + ├── README.txt + └── ..... + + Args: + data_root (str): The root directory for dataset. + split (str): Choose from 'train', 'query' and 'gallery'. + Defaults to 'train'. + data_prefix (str | dict): Prefix for training data. + Defaults to 'Img'. + ann_file (str): Annotation file path, path relative to + ``data_root``. Defaults to 'Eval/list_eval_partition.txt'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import InShop + >>> + >>> # build train InShop dataset + >>> inshop_train_cfg = dict(data_root='data/inshop', split='train') + >>> inshop_train = InShop(**inshop_train_cfg) + >>> inshop_train + Dataset InShop + Number of samples: 25882 + The `CLASSES` meta info is not set. + Root of dataset: data/inshop + >>> + >>> # build query InShop dataset + >>> inshop_query_cfg = dict(data_root='data/inshop', split='query') + >>> inshop_query = InShop(**inshop_query_cfg) + >>> inshop_query + Dataset InShop + Number of samples: 14218 + The `CLASSES` meta info is not set. + Root of dataset: data/inshop + >>> + >>> # build gallery InShop dataset + >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery') + >>> inshop_gallery = InShop(**inshop_gallery_cfg) + >>> inshop_gallery + Dataset InShop + Number of samples: 12612 + The `CLASSES` meta info is not set. + Root of dataset: data/inshop + """ + + def __init__(self, + data_root: str, + split: str = 'train', + data_prefix: str = 'Img', + ann_file: str = 'Eval/list_eval_partition.txt', + **kwargs): + + assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \ + f" must be one of ['train', 'query', 'gallery'], bu get '{split}'" + self.backend = get_file_backend(data_root, enable_singleton=True) + self.split = split + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + **kwargs) + + def _process_annotations(self): + lines = list_from_file(self.ann_file) + + anno_train = dict(metainfo=dict(), data_list=list()) + anno_gallery = dict(metainfo=dict(), data_list=list()) + + # item_id to label, each item corresponds to one class label + class_num = 0 + gt_label_train = {} + + # item_id to label, each label corresponds to several items + gallery_num = 0 + gt_label_gallery = {} + + # (lines[0], lines[1]) is the image number and the field name; + # Each line format as 'image_name, item_id, evaluation_status' + for line in lines[2:]: + img_name, item_id, status = line.split() + img_path = self.backend.join_path(self.img_prefix, img_name) + if status == 'train': + if item_id not in gt_label_train: + gt_label_train[item_id] = class_num + class_num += 1 + # item_id to class_id (for the training set) + anno_train['data_list'].append( + dict(img_path=img_path, gt_label=gt_label_train[item_id])) + elif status == 'gallery': + if item_id not in gt_label_gallery: + gt_label_gallery[item_id] = [] + # Since there are multiple images for each item, + # record the corresponding item for each image. + gt_label_gallery[item_id].append(gallery_num) + anno_gallery['data_list'].append( + dict(img_path=img_path, sample_idx=gallery_num)) + gallery_num += 1 + + if self.split == 'train': + anno_train['metainfo']['class_number'] = class_num + anno_train['metainfo']['sample_number'] = \ + len(anno_train['data_list']) + return anno_train + elif self.split == 'gallery': + anno_gallery['metainfo']['sample_number'] = gallery_num + return anno_gallery + + # Generate the label for the query(val) set + anno_query = dict(metainfo=dict(), data_list=list()) + query_num = 0 + for line in lines[2:]: + img_name, item_id, status = line.split() + img_path = self.backend.join_path(self.img_prefix, img_name) + if status == 'query': + anno_query['data_list'].append( + dict( + img_path=img_path, gt_label=gt_label_gallery[item_id])) + query_num += 1 + + anno_query['metainfo']['sample_number'] = query_num + return anno_query + + def load_data_list(self): + """load data list. + + For the train set, return image and ground truth label. For the query + set, return image and ids of images in gallery. For the gallery set, + return image and its id. + """ + data_info = self._process_annotations() + data_list = data_info['data_list'] + return data_list + + def extra_repr(self): + """The extra repr information of the dataset.""" + body = [f'Root of dataset: \t{self.data_root}'] + return body diff --git a/mmpretrain/datasets/mnist.py b/mmpretrain/datasets/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..425267fe8034860d3b78c6af5b565ddb6efc7c10 --- /dev/null +++ b/mmpretrain/datasets/mnist.py @@ -0,0 +1,234 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import codecs +from typing import List, Optional +from urllib.parse import urljoin + +import mmengine.dist as dist +import numpy as np +import torch +from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES +from .utils import (download_and_extract_archive, open_maybe_compressed_file, + rm_suffix) + + +@DATASETS.register_module() +class MNIST(BaseDataset): + """`MNIST `_ Dataset. + + This implementation is modified from + https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py + + Args: + data_root (str): The root directory of the MNIST Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ # noqa: E501 + + url_prefix = 'http://yann.lecun.com/exdb/mnist/' + # train images and labels + train_list = [ + ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'], + ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'], + ] + # test images and labels + test_list = [ + ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'], + ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'], + ] + METAINFO = {'classes': MNIST_CATEGORITES} + + def __init__(self, + data_root: str = '', + split: str = 'train', + metainfo: Optional[dict] = None, + download: bool = True, + data_prefix: str = '', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. ' + 'The training set will be used.') + + if not data_root and not data_prefix: + raise RuntimeError('Please set ``data_root`` to' + 'specify the dataset path') + + self.download = download + super().__init__( + # The MNIST dataset doesn't need specify annotation file + ann_file='', + metainfo=metainfo, + data_root=data_root, + data_prefix=dict(root=data_prefix), + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + root = self.data_prefix['root'] + backend = get_file_backend(root, enable_singleton=True) + + if dist.is_main_process() and not self._check_exists(): + if not isinstance(backend, LocalBackend): + raise RuntimeError(f'The dataset on {root} is not integrated, ' + f'please manually handle it.') + + if self.download: + self._download() + else: + raise RuntimeError( + f'Cannot find {self.__class__.__name__} dataset in ' + f"{self.data_prefix['root']}, you can specify " + '`download=True` to download automatically.') + + dist.barrier() + assert self._check_exists(), \ + 'Download failed or shared storage is unavailable. Please ' \ + f'download the dataset manually through {self.url_prefix}.' + + if not self.test_mode: + file_list = self.train_list + else: + file_list = self.test_list + + # load data from SN3 files + imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0]))) + gt_labels = read_label_file( + join_path(root, rm_suffix(file_list[1][0]))) + + data_infos = [] + for img, gt_label in zip(imgs, gt_labels): + gt_label = np.array(gt_label, dtype=np.int64) + info = {'img': img.numpy(), 'gt_label': gt_label} + data_infos.append(info) + return data_infos + + def _check_exists(self): + """Check the exists of data files.""" + root = self.data_prefix['root'] + + for filename, _ in (self.train_list + self.test_list): + # get extracted filename of data + extract_filename = rm_suffix(filename) + fpath = join_path(root, extract_filename) + if not exists(fpath): + return False + return True + + def _download(self): + """Download and extract data files.""" + root = self.data_prefix['root'] + + for filename, md5 in (self.train_list + self.test_list): + url = urljoin(self.url_prefix, filename) + download_and_extract_archive( + url, download_root=root, filename=filename, md5=md5) + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [f"Prefix of data: \t{self.data_prefix['root']}"] + return body + + +@DATASETS.register_module() +class FashionMNIST(MNIST): + """`Fashion-MNIST `_ + Dataset. + + Args: + data_root (str): The root directory of the MNIST Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' + # train images and labels + train_list = [ + ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'], + ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'], + ] + # test images and labels + test_list = [ + ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'], + ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'], + ] + METAINFO = {'classes': FASHIONMNIST_CATEGORITES} + + +def get_int(b: bytes) -> int: + """Convert bytes to int.""" + return int(codecs.encode(b, 'hex'), 16) + + +def read_sn3_pascalvincent_tensor(path: str, + strict: bool = True) -> torch.Tensor: + """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx- + io.lsh'). + + Argument may be a filename, compressed filename, or file object. + """ + # typemap + if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'): + read_sn3_pascalvincent_tensor.typemap = { + 8: (torch.uint8, np.uint8, np.uint8), + 9: (torch.int8, np.int8, np.int8), + 11: (torch.int16, np.dtype('>i2'), 'i2'), + 12: (torch.int32, np.dtype('>i4'), 'i4'), + 13: (torch.float32, np.dtype('>f4'), 'f4'), + 14: (torch.float64, np.dtype('>f8'), 'f8') + } + # read + with open_maybe_compressed_file(path) as f: + data = f.read() + # parse + magic = get_int(data[0:4]) + nd = magic % 256 + ty = magic // 256 + assert nd >= 1 and nd <= 3 + assert ty >= 8 and ty <= 14 + m = read_sn3_pascalvincent_tensor.typemap[ty] + s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)] + parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) + assert parsed.shape[0] == np.prod(s) or not strict + return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) + + +def read_label_file(path: str) -> torch.Tensor: + """Read labels from SN3 label file.""" + with open(path, 'rb') as f: + x = read_sn3_pascalvincent_tensor(f, strict=False) + assert (x.dtype == torch.uint8) + assert (x.ndimension() == 1) + return x.long() + + +def read_image_file(path: str) -> torch.Tensor: + """Read images from SN3 image file.""" + with open(path, 'rb') as f: + x = read_sn3_pascalvincent_tensor(f, strict=False) + assert (x.dtype == torch.uint8) + assert (x.ndimension() == 3) + return x diff --git a/mmpretrain/datasets/multi_label.py b/mmpretrain/datasets/multi_label.py new file mode 100644 index 0000000000000000000000000000000000000000..58a9c7cd5f097689d29700004e2ed815934a1594 --- /dev/null +++ b/mmpretrain/datasets/multi_label.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class MultiLabelDataset(BaseDataset): + """Multi-label Dataset. + + This dataset support annotation file in `OpenMMLab 2.0 style annotation + format`. + + The annotation format is shown as follows. + + .. code-block:: none + + { + "metainfo": + { + "classes":['A', 'B', 'C'....] + }, + "data_list": + [ + { + "img_path": "test_img1.jpg", + 'gt_label': [0, 1], + }, + { + "img_path": "test_img2.jpg", + 'gt_label': [2], + }, + ] + .... + } + + + Args: + ann_file (str): Annotation file path. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. + classes (str | Sequence[str], optional): Specify names of classes. + + - If is string, it should be a file path, and the every line of + the file is a name of a class. + - If is a sequence of string, every item is a name of class. + - If is None, use categories information in ``metainfo`` argument, + annotation file or the class attribute ``METAINFO``. + + Defaults to None. + """ + + def get_cat_ids(self, idx: int) -> List[int]: + """Get category ids by index. + + Args: + idx (int): Index of data. + + Returns: + cat_ids (List[int]): Image categories of specified index. + """ + return self.get_data_info(idx)['gt_label'] diff --git a/mmpretrain/datasets/multi_task.py b/mmpretrain/datasets/multi_task.py new file mode 100644 index 0000000000000000000000000000000000000000..443df0e7d7de11962d472d33b25b4bbff562524f --- /dev/null +++ b/mmpretrain/datasets/multi_task.py @@ -0,0 +1,337 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from os import PathLike +from typing import Optional, Sequence + +import mmengine +from mmcv.transforms import Compose +from mmengine.fileio import get_file_backend + +from .builder import DATASETS + + +def expanduser(path): + if isinstance(path, (str, PathLike)): + return osp.expanduser(path) + else: + return path + + +def isabs(uri): + return osp.isabs(uri) or ('://' in uri) + + +@DATASETS.register_module() +class MultiTaskDataset: + """Custom dataset for multi-task dataset. + + To use the dataset, please generate and provide an annotation file in the + below format: + + .. code-block:: json + + { + "metainfo": { + "tasks": + [ + 'gender' + 'wear' + ] + }, + "data_list": [ + { + "img_path": "a.jpg", + gt_label:{ + "gender": 0, + "wear": [1, 0, 1, 0] + } + }, + { + "img_path": "b.jpg", + gt_label:{ + "gender": 1, + "wear": [1, 0, 1, 0] + } + } + ] + } + + Assume we put our dataset in the ``data/mydataset`` folder in the + repository and organize it as the below format: :: + + mmpretrain/ + └── data + └── mydataset + ├── annotation + │   ├── train.json + │   ├── test.json + │   └── val.json + ├── train + │   ├── a.jpg + │   └── ... + ├── test + │   ├── b.jpg + │   └── ... + └── val + ├── c.jpg + └── ... + + We can use the below config to build datasets: + + .. code:: python + + >>> from mmpretrain.datasets import build_dataset + >>> train_cfg = dict( + ... type="MultiTaskDataset", + ... ann_file="annotation/train.json", + ... data_root="data/mydataset", + ... # The `img_path` field in the train annotation file is relative + ... # to the `train` folder. + ... data_prefix='train', + ... ) + >>> train_dataset = build_dataset(train_cfg) + + Or we can put all files in the same folder: :: + + mmpretrain/ + └── data + └── mydataset + ├── train.json + ├── test.json + ├── val.json + ├── a.jpg + ├── b.jpg + ├── c.jpg + └── ... + + And we can use the below config to build datasets: + + .. code:: python + + >>> from mmpretrain.datasets import build_dataset + >>> train_cfg = dict( + ... type="MultiTaskDataset", + ... ann_file="train.json", + ... data_root="data/mydataset", + ... # the `data_prefix` is not required since all paths are + ... # relative to the `data_root`. + ... ) + >>> train_dataset = build_dataset(train_cfg) + + + Args: + ann_file (str): The annotation file path. It can be either absolute + path or relative path to the ``data_root``. + metainfo (dict, optional): The extra meta information. It should be + a dict with the same format as the ``"metainfo"`` field in the + annotation file. Defaults to None. + data_root (str, optional): The root path of the data directory. It's + the prefix of the ``data_prefix`` and the ``ann_file``. And it can + be a remote path like "s3://openmmlab/xxx/". Defaults to None. + data_prefix (str, optional): The base folder relative to the + ``data_root`` for the ``"img_path"`` field in the annotation file. + Defaults to None. + pipeline (Sequence[dict]): A list of dict, where each element + represents a operation defined in + :mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple. + test_mode (bool): in train mode or test mode. Defaults to False. + """ + METAINFO = dict() + + def __init__(self, + ann_file: str, + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: Optional[str] = None, + pipeline: Sequence = (), + test_mode: bool = False): + + self.data_root = expanduser(data_root) + + # Inference the file client + if self.data_root is not None: + self.file_backend = get_file_backend(uri=self.data_root) + else: + self.file_backend = None + + self.ann_file = self._join_root(expanduser(ann_file)) + self.data_prefix = self._join_root(data_prefix) + + self.test_mode = test_mode + self.pipeline = Compose(pipeline) + self.data_list = self.load_data_list(self.ann_file, metainfo) + + def _join_root(self, path): + """Join ``self.data_root`` with the specified path. + + If the path is an absolute path, just return the path. And if the + path is None, return ``self.data_root``. + + Examples: + >>> self.data_root = 'a/b/c' + >>> self._join_root('d/e/') + 'a/b/c/d/e' + >>> self._join_root('https://openmmlab.com') + 'https://openmmlab.com' + >>> self._join_root(None) + 'a/b/c' + """ + if path is None: + return self.data_root + if isabs(path): + return path + + joined_path = self.file_backend.join_path(self.data_root, path) + return joined_path + + @classmethod + def _get_meta_info(cls, in_metainfo: dict = None) -> dict: + """Collect meta information from the dictionary of meta. + + Args: + in_metainfo (dict): Meta information dict. + + Returns: + dict: Parsed meta information. + """ + # `cls.METAINFO` will be overwritten by in_meta + metainfo = copy.deepcopy(cls.METAINFO) + if in_metainfo is None: + return metainfo + + metainfo.update(in_metainfo) + + return metainfo + + def load_data_list(self, ann_file, metainfo_override=None): + """Load annotations from an annotation file. + + Args: + ann_file (str): Absolute annotation file path if ``self.root=None`` + or relative path if ``self.root=/path/to/data/``. + + Returns: + list[dict]: A list of annotation. + """ + annotations = mmengine.load(ann_file) + if not isinstance(annotations, dict): + raise TypeError(f'The annotations loaded from annotation file ' + f'should be a dict, but got {type(annotations)}!') + if 'data_list' not in annotations: + raise ValueError('The annotation file must have the `data_list` ' + 'field.') + metainfo = annotations.get('metainfo', {}) + raw_data_list = annotations['data_list'] + + # Set meta information. + assert isinstance(metainfo, dict), 'The `metainfo` field in the '\ + f'annotation file should be a dict, but got {type(metainfo)}' + if metainfo_override is not None: + assert isinstance(metainfo_override, dict), 'The `metainfo` ' \ + f'argument should be a dict, but got {type(metainfo_override)}' + metainfo.update(metainfo_override) + self._metainfo = self._get_meta_info(metainfo) + + data_list = [] + for i, raw_data in enumerate(raw_data_list): + try: + data_list.append(self.parse_data_info(raw_data)) + except AssertionError as e: + raise RuntimeError( + f'The format check fails during parse the item {i} of ' + f'the annotation file with error: {e}') + return data_list + + def parse_data_info(self, raw_data): + """Parse raw annotation to target format. + + This method will return a dict which contains the data information of a + sample. + + Args: + raw_data (dict): Raw data information load from ``ann_file`` + + Returns: + dict: Parsed annotation. + """ + assert isinstance(raw_data, dict), \ + f'The item should be a dict, but got {type(raw_data)}' + assert 'img_path' in raw_data, \ + "The item doesn't have `img_path` field." + data = dict( + img_path=self._join_root(raw_data['img_path']), + gt_label=raw_data['gt_label'], + ) + return data + + @property + def metainfo(self) -> dict: + """Get meta information of dataset. + + Returns: + dict: meta information collected from ``cls.METAINFO``, + annotation file and metainfo argument during instantiation. + """ + return copy.deepcopy(self._metainfo) + + def prepare_data(self, idx): + """Get data processed by ``self.pipeline``. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + results = copy.deepcopy(self.data_list[idx]) + return self.pipeline(results) + + def __len__(self): + """Get the length of the whole dataset. + + Returns: + int: The length of filtered dataset. + """ + return len(self.data_list) + + def __getitem__(self, idx): + """Get the idx-th image and data information of dataset after + ``self.pipeline``. + + Args: + idx (int): The index of of the data. + + Returns: + dict: The idx-th image and data information after + ``self.pipeline``. + """ + return self.prepare_data(idx) + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. + """ + head = 'Dataset ' + self.__class__.__name__ + body = [f'Number of samples: \t{self.__len__()}'] + if self.data_root is not None: + body.append(f'Root location: \t{self.data_root}') + body.append(f'Annotation file: \t{self.ann_file}') + if self.data_prefix is not None: + body.append(f'Prefix of images: \t{self.data_prefix}') + # -------------------- extra repr -------------------- + tasks = self.metainfo['tasks'] + body.append(f'For {len(tasks)} tasks') + for task in tasks: + body.append(f' {task} ') + # ---------------------------------------------------- + + if len(self.pipeline.transforms) > 0: + body.append('With transforms:') + for t in self.pipeline.transforms: + body.append(f' {t}') + + lines = [head] + [' ' * 4 + line for line in body] + return '\n'.join(lines) diff --git a/mmpretrain/datasets/nlvr2.py b/mmpretrain/datasets/nlvr2.py new file mode 100644 index 0000000000000000000000000000000000000000..0063090657714406049a6daa6fa3c0d868422590 --- /dev/null +++ b/mmpretrain/datasets/nlvr2.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from typing import List + +from mmengine.fileio import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class NLVR2(BaseDataset): + """COCO Caption dataset.""" + + def load_data_list(self) -> List[dict]: + """Load data list.""" + + data_list = [] + img_prefix = self.data_prefix['img_path'] + file_backend = get_file_backend(img_prefix) + examples = list_from_file(self.ann_file) + + for example in examples: + example = json.loads(example) + prefix = example['identifier'].rsplit('-', 1)[0] + train_data = {} + train_data['text'] = example['sentence'] + train_data['gt_label'] = {'True': 1, 'False': 0}[example['label']] + train_data['img_path'] = [ + file_backend.join_path(img_prefix, prefix + f'-img{i}.png') + for i in range(2) + ] + + data_list.append(train_data) + + return data_list diff --git a/mmpretrain/datasets/nocaps.py b/mmpretrain/datasets/nocaps.py new file mode 100644 index 0000000000000000000000000000000000000000..65116e9cecc2d9983ef72ca3eee24ff7baedacc0 --- /dev/null +++ b/mmpretrain/datasets/nocaps.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend +from pycocotools.coco import COCO + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class NoCaps(BaseDataset): + """NoCaps dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``.. + ann_file (str): Annotation file path. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + + file_backend = get_file_backend(img_prefix) + data_list = [] + for ann in coco.anns.values(): + image_id = ann['image_id'] + image_path = file_backend.join_path( + img_prefix, coco.imgs[image_id]['file_name']) + data_info = { + 'image_id': image_id, + 'img_path': image_path, + 'gt_caption': None + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/ocr_vqa.py b/mmpretrain/datasets/ocr_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..55aa6913e3c4464444e8b971ccabf68aa2d99904 --- /dev/null +++ b/mmpretrain/datasets/ocr_vqa.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class OCRVQA(BaseDataset): + """OCR-VQA dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + + split_dict = {1: 'train', 2: 'val', 3: 'test'} + + annotations = mmengine.load(self.ann_file) + + # ann example + # "761183272": { + # "imageURL": \ + # "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg", + # "questions": [ + # "Who wrote this book?", + # "What is the title of this book?", + # "What is the genre of this book?", + # "Is this a games related book?", + # "What is the year printed on this calendar?"], + # "answers": [ + # "Sandra Boynton", + # "Mom's Family Wall Calendar 2016", + # "Calendars", + # "No", + # "2016"], + # "title": "Mom's Family Wall Calendar 2016", + # "authorName": "Sandra Boynton", + # "genre": "Calendars", + # "split": 1 + # }, + + data_list = [] + + for key, ann in annotations.items(): + if self.split != split_dict[ann['split']]: + continue + + extension = osp.splitext(ann['imageURL'])[1] + if extension not in ['.jpg', '.png']: + continue + img_path = mmengine.join_path(self.data_prefix['img_path'], + key + extension) + for question, answer in zip(ann['questions'], ann['answers']): + data_info = {} + data_info['img_path'] = img_path + data_info['question'] = question + data_info['gt_answer'] = answer + data_info['gt_answer_weight'] = [1.0] + + data_info['imageURL'] = ann['imageURL'] + data_info['title'] = ann['title'] + data_info['authorName'] = ann['authorName'] + data_info['genre'] = ann['genre'] + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/oxfordiiitpet.py b/mmpretrain/datasets/oxfordiiitpet.py new file mode 100644 index 0000000000000000000000000000000000000000..23c8b7db8679e99c6ed2698b9eb140cd6151d445 --- /dev/null +++ b/mmpretrain/datasets/oxfordiiitpet.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import OxfordIIITPet_CATEGORIES + + +@DATASETS.register_module() +class OxfordIIITPet(BaseDataset): + """The Oxford-IIIT Pets Dataset. + + Support the `Oxford-IIIT Pets Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Oxford-IIIT_Pets dataset directory: :: + + Oxford-IIIT_Pets + ├── images + │ ├── Abyssinian_1.jpg + │ ├── Abyssinian_2.jpg + │ └── ... + ├── annotations + │ ├── trainval.txt + │ ├── test.txt + │ ├── list.txt + │ └── ... + └── .... + + Args: + data_root (str): The root directory for Oxford-IIIT Pets dataset. + split (str, optional): The dataset split, supports "trainval" and "test". + Default to "trainval". + + Examples: + >>> from mmpretrain.datasets import OxfordIIITPet + >>> train_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='trainval') + >>> train_dataset + Dataset OxfordIIITPet + Number of samples: 3680 + Number of categories: 37 + Root of dataset: data/Oxford-IIIT_Pets + >>> test_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='test') + >>> test_dataset + Dataset OxfordIIITPet + Number of samples: 3669 + Number of categories: 37 + Root of dataset: data/Oxford-IIIT_Pets + """ # noqa: E501 + + METAINFO = {'classes': OxfordIIITPet_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + + splits = ['trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + if split == 'trainval': + ann_file = self.backend.join_path('annotations', 'trainval.txt') + else: + ann_file = self.backend.join_path('annotations', 'test.txt') + + data_prefix = 'images' + test_mode = split == 'test' + + super(OxfordIIITPet, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + for pair in pairs: + img_name, class_id, _, _ = pair.split() + img_name = f'{img_name}.jpg' + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = int(class_id) - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/places205.py b/mmpretrain/datasets/places205.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ba1ff631a7a4840b66cf63ec53585ec064560d --- /dev/null +++ b/mmpretrain/datasets/places205.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +from mmpretrain.registry import DATASETS +from .categories import PLACES205_CATEGORIES +from .custom import CustomDataset + + +@DATASETS.register_module() +class Places205(CustomDataset): + """`Places205 `_ Dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults + to ''. + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + """ + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + METAINFO = {'classes': PLACES205_CATEGORIES} + + def __init__(self, + data_root: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + **kwargs): + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) diff --git a/mmpretrain/datasets/refcoco.py b/mmpretrain/datasets/refcoco.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f2a943f73fdab493a47bbcd1d0ea6385ec60fa --- /dev/null +++ b/mmpretrain/datasets/refcoco.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +import numpy as np +from mmengine.dataset import BaseDataset +from pycocotools.coco import COCO + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class RefCOCO(BaseDataset): + """RefCOCO dataset. + + Args: + ann_file (str): Annotation file path. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str): Prefix for training data. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root, + ann_file, + data_prefix, + split_file, + split='train', + **kwargs): + self.split_file = split_file + self.split = split + + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwargs, + ) + + def _join_prefix(self): + if not mmengine.is_abs(self.split_file) and self.split_file: + self.split_file = osp.join(self.data_root, self.split_file) + + return super()._join_prefix() + + def load_data_list(self) -> List[dict]: + """Load data list.""" + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + splits = mmengine.load(self.split_file, file_format='pkl') + img_prefix = self.data_prefix['img_path'] + + data_list = [] + join_path = mmengine.fileio.get_file_backend(img_prefix).join_path + for refer in splits: + if refer['split'] != self.split: + continue + + ann = coco.anns[refer['ann_id']] + img = coco.imgs[ann['image_id']] + sentences = refer['sentences'] + bbox = np.array(ann['bbox'], dtype=np.float32) + bbox[2:4] = bbox[0:2] + bbox[2:4] # XYWH -> XYXY + + for sent in sentences: + data_info = { + 'img_path': join_path(img_prefix, img['file_name']), + 'image_id': ann['image_id'], + 'ann_id': ann['id'], + 'text': sent['sent'], + 'gt_bboxes': bbox[None, :], + } + data_list.append(data_info) + + if len(data_list) == 0: + raise ValueError(f'No sample in split "{self.split}".') + + return data_list diff --git a/mmpretrain/datasets/samplers/__init__.py b/mmpretrain/datasets/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bccf9c34659e19764871a696260cf5884696ca1 --- /dev/null +++ b/mmpretrain/datasets/samplers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .repeat_aug import RepeatAugSampler +from .sequential import SequentialSampler + +__all__ = ['RepeatAugSampler', 'SequentialSampler'] diff --git a/mmpretrain/datasets/samplers/__pycache__/__init__.cpython-38.pyc b/mmpretrain/datasets/samplers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9672735d3373a8d9e4d0859cd30310e799556d05 Binary files /dev/null and b/mmpretrain/datasets/samplers/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/datasets/samplers/__pycache__/repeat_aug.cpython-38.pyc b/mmpretrain/datasets/samplers/__pycache__/repeat_aug.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de91371076f2d2a1df328f5b24cf9c3aea945235 Binary files /dev/null and b/mmpretrain/datasets/samplers/__pycache__/repeat_aug.cpython-38.pyc differ diff --git a/mmpretrain/datasets/samplers/__pycache__/sequential.cpython-38.pyc b/mmpretrain/datasets/samplers/__pycache__/sequential.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c150cc6be8d91701e24d7e2662e5567108004164 Binary files /dev/null and b/mmpretrain/datasets/samplers/__pycache__/sequential.cpython-38.pyc differ diff --git a/mmpretrain/datasets/samplers/repeat_aug.py b/mmpretrain/datasets/samplers/repeat_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..d833a1954d7d9d181c368d5b3b956c25df241c1a --- /dev/null +++ b/mmpretrain/datasets/samplers/repeat_aug.py @@ -0,0 +1,101 @@ +import math +from typing import Iterator, Optional, Sized + +import torch +from mmengine.dist import get_dist_info, is_main_process, sync_random_seed +from torch.utils.data import Sampler + +from mmpretrain.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class RepeatAugSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset for + distributed, with repeated augmentation. It ensures that different each + augmented version of a sample will be visible to a different process (GPU). + Heavily based on torch.utils.data.DistributedSampler. + + This sampler was taken from + https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py + Used in + Copyright (c) 2015-present, Facebook, Inc. + + Args: + dataset (Sized): The dataset. + shuffle (bool): Whether shuffle the dataset or not. Defaults to True. + num_repeats (int): The repeat times of every sample. Defaults to 3. + seed (int, optional): Random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Defaults to None. + """ + + def __init__(self, + dataset: Sized, + shuffle: bool = True, + num_repeats: int = 3, + seed: Optional[int] = None): + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.shuffle = shuffle + if not self.shuffle and is_main_process(): + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.warning('The RepeatAugSampler always picks a ' + 'fixed part of data if `shuffle=False`.') + + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.num_repeats = num_repeats + + # The number of repeated samples in the rank + self.num_samples = math.ceil( + len(self.dataset) * num_repeats / world_size) + # The total number of repeated samples in all ranks. + self.total_size = self.num_samples * world_size + # The number of selected samples in the rank + self.num_selected_samples = math.ceil(len(self.dataset) / world_size) + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + # deterministically shuffle based on epoch and seed + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....] + indices = [x for x in indices for _ in range(self.num_repeats)] + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + indices += indices[:padding_size] + assert len(indices) == self.total_size + + # subsample per rank + indices = indices[self.rank:self.total_size:self.world_size] + assert len(indices) == self.num_samples + + # return up to num selected samples + return iter(indices[:self.num_selected_samples]) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_selected_samples + + def set_epoch(self, epoch: int) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch diff --git a/mmpretrain/datasets/samplers/sequential.py b/mmpretrain/datasets/samplers/sequential.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b940c2eabc2ab9c2401cd1923776fc067e9f6c --- /dev/null +++ b/mmpretrain/datasets/samplers/sequential.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterator + +import torch +from mmengine.dataset import DefaultSampler + +from mmpretrain.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class SequentialSampler(DefaultSampler): + """Sequential sampler which supports different subsample policy. + + Args: + dataset (Sized): The dataset. + round_up (bool): Whether to add extra samples to make the number of + samples evenly divisible by the world size. Defaults to True. + subsample_type (str): The method to subsample data on different rank. + Supported type: + + - ``'default'``: Original torch behavior. Sample the examples one + by one for each GPU in terms. For instance, 8 examples on 2 GPUs, + GPU0: [0,2,4,8], GPU1: [1,3,5,7] + - ``'sequential'``: Subsample all examples to n chunk sequntially. + For instance, 8 examples on 2 GPUs, + GPU0: [0,1,2,3], GPU1: [4,5,6,7] + """ + + def __init__(self, subsample_type: str = 'default', **kwargs) -> None: + super().__init__(shuffle=False, **kwargs) + + if subsample_type not in ['default', 'sequential']: + raise ValueError(f'Unsupported subsample typer "{subsample_type}",' + ' please choose from ["default", "sequential"]') + self.subsample_type = subsample_type + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + + # subsample + if self.subsample_type == 'default': + indices = indices[self.rank:self.total_size:self.world_size] + elif self.subsample_type == 'sequential': + num_samples_per_rank = self.total_size // self.world_size + indices = indices[self.rank * + num_samples_per_rank:(self.rank + 1) * + num_samples_per_rank] + + return iter(indices) diff --git a/mmpretrain/datasets/scienceqa.py b/mmpretrain/datasets/scienceqa.py new file mode 100644 index 0000000000000000000000000000000000000000..8e442491be85540980c0309b65d32a12c9c85542 --- /dev/null +++ b/mmpretrain/datasets/scienceqa.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import Callable, List, Sequence + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class ScienceQA(BaseDataset): + """ScienceQA dataset. + + This dataset is used to load the multimodal data of ScienceQA dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. + split (str): The split of dataset. Options: ``train``, ``val``, + ``test``, ``trainval``, ``minival``, and ``minitest``. + split_file (str): The split file of dataset, which contains the + ids of data samples in the split. + ann_file (str): Annotation file path. + image_only (bool): Whether only to load data with image. Defaults to + False. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + split: str, + split_file: str, + ann_file: str, + image_only: bool = False, + data_prefix: dict = dict(img_path=''), + pipeline: Sequence[Callable] = (), + **kwargs): + assert split in [ + 'train', 'val', 'test', 'trainval', 'minival', 'minitest' + ], f'Invalid split {split}' + self.split = split + self.split_file = os.path.join(data_root, split_file) + self.image_only = image_only + + super().__init__( + data_root=data_root, + ann_file=ann_file, + data_prefix=data_prefix, + pipeline=pipeline, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + current_data_split = mmengine.load(self.split_file)[self.split] # noqa + + file_backend = get_file_backend(img_prefix) + + data_list = [] + for data_id in current_data_split: + ann = annotations[data_id] + if self.image_only and ann['image'] is None: + continue + data_info = { + 'image_id': + data_id, + 'question': + ann['question'], + 'choices': + ann['choices'], + 'gt_answer': + ann['answer'], + 'hint': + ann['hint'], + 'image_name': + ann['image'], + 'task': + ann['task'], + 'grade': + ann['grade'], + 'subject': + ann['subject'], + 'topic': + ann['topic'], + 'category': + ann['category'], + 'skill': + ann['skill'], + 'lecture': + ann['lecture'], + 'solution': + ann['solution'], + 'split': + ann['split'], + 'img_path': + file_backend.join_path(img_prefix, data_id, ann['image']) + if ann['image'] is not None else None, + 'has_image': + True if ann['image'] is not None else False, + } + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/stanfordcars.py b/mmpretrain/datasets/stanfordcars.py new file mode 100644 index 0000000000000000000000000000000000000000..355697943cf693869f35f2a0bd71abdfa0396722 --- /dev/null +++ b/mmpretrain/datasets/stanfordcars.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import STANFORDCARS_CATEGORIES + + +@DATASETS.register_module() +class StanfordCars(BaseDataset): + """The Stanford Cars Dataset. + + Support the `Stanford Cars Dataset `_ Dataset. + The official website provides two ways to organize the dataset. + Therefore, after downloading and decompression, the dataset directory structure is as follows. + + Stanford Cars dataset directory: :: + + Stanford_Cars + ├── car_ims + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + └── cars_annos.mat + + or :: + + Stanford_Cars + ├── cars_train + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + ├── cars_test + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + └── devkit + ├── cars_meta.mat + ├── cars_train_annos.mat + ├── cars_test_annos.mat + ├── cars_test_annoswithlabels.mat + ├── eval_train.m + └── train_perfect_preds.txt + + Args: + data_root (str): The root directory for Stanford Cars dataset. + split (str, optional): The dataset split, supports "train" + and "test". Default to "train". + + Examples: + >>> from mmpretrain.datasets import StanfordCars + >>> train_dataset = StanfordCars(data_root='data/Stanford_Cars', split='train') + >>> train_dataset + Dataset StanfordCars + Number of samples: 8144 + Number of categories: 196 + Root of dataset: data/Stanford_Cars + >>> test_dataset = StanfordCars(data_root='data/Stanford_Cars', split='test') + >>> test_dataset + Dataset StanfordCars + Number of samples: 8041 + Number of categories: 196 + Root of dataset: data/Stanford_Cars + """ # noqa: E501 + + METAINFO = {'classes': STANFORDCARS_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'train', **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + test_mode = split == 'test' + self.backend = get_file_backend(data_root, enable_singleton=True) + + anno_file_path = self.backend.join_path(data_root, 'cars_annos.mat') + if self.backend.exists(anno_file_path): + ann_file = 'cars_annos.mat' + data_prefix = '' + else: + if test_mode: + ann_file = self.backend.join_path( + 'devkit', 'cars_test_annos_withlabels.mat') + data_prefix = 'cars_test' + else: + ann_file = self.backend.join_path('devkit', + 'cars_train_annos.mat') + data_prefix = 'cars_train' + + if not self.backend.exists( + self.backend.join_path(data_root, ann_file)): + doc_url = 'https://mmpretrain.readthedocs.io/en/latest/api/datasets.html#stanfordcars' # noqa: E501 + raise RuntimeError( + f'The dataset is incorrectly organized, please \ + refer to {doc_url} and reorganize your folders.') + + super(StanfordCars, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + data = mat4py.loadmat(self.ann_file)['annotations'] + + data_list = [] + if 'test' in data.keys(): + # first way + img_paths, labels, test = data['relative_im_path'], data[ + 'class'], data['test'] + num = len(img_paths) + assert num == len(labels) == len(test), 'get error ann file' + for i in range(num): + if not self.test_mode and test[i] == 1: + continue + if self.test_mode and test[i] == 0: + continue + img_path = self.backend.join_path(self.img_prefix, + img_paths[i]) + gt_label = labels[i] - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + else: + # second way + img_names, labels = data['fname'], data['class'] + num = len(img_names) + assert num == len(labels), 'get error ann file' + for i in range(num): + img_path = self.backend.join_path(self.img_prefix, + img_names[i]) + gt_label = labels[i] - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/sun397.py b/mmpretrain/datasets/sun397.py new file mode 100644 index 0000000000000000000000000000000000000000..1039a0690f8096082d5c55f89d743478fdf5b22d --- /dev/null +++ b/mmpretrain/datasets/sun397.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import SUN397_CATEGORIES + + +@DATASETS.register_module() +class SUN397(BaseDataset): + """The SUN397 Dataset. + + Support the `SUN397 Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + SUN397 dataset directory: :: + + SUN397 + ├── SUN397 + │ ├── a + │ │ ├── abbey + │ | | ├── sun_aaalbzqrimafwbiv.jpg + │ | | └── ... + │ │ ├── airplane_cabin + │ | | ├── sun_aadqdkqaslqqoblu.jpg + │ | | └── ... + │ | └── ... + │ ├── b + │ │ └── ... + │ ├── c + │ │ └── ... + │ └── ... + └── Partitions + ├── ClassName.txt + ├── Training_01.txt + ├── Testing_01.txt + └── ... + + Args: + data_root (str): The root directory for Stanford Cars dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + + Examples: + >>> from mmpretrain.datasets import SUN397 + >>> train_dataset = SUN397(data_root='data/SUN397', split='train') + >>> train_dataset + Dataset SUN397 + Number of samples: 19850 + Number of categories: 397 + Root of dataset: data/SUN397 + >>> test_dataset = SUN397(data_root='data/SUN397', split='test') + >>> test_dataset + Dataset SUN397 + Number of samples: 19850 + Number of categories: 397 + Root of dataset: data/SUN397 + + **Note that some images are not a jpg file although the name ends with ".jpg". + The backend of SUN397 should be "pillow" as below to read these images properly,** + + .. code-block:: python + + pipeline = [ + dict(type='LoadImageFromFile', imdecode_backend='pillow'), + dict(type='RandomResizedCrop', scale=224), + dict(type='PackInputs') + ] + """ # noqa: E501 + + METAINFO = {'classes': SUN397_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'train', **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + if split == 'train': + ann_file = self.backend.join_path('Partitions', 'Training_01.txt') + else: + ann_file = self.backend.join_path('Partitions', 'Testing_01.txt') + + data_prefix = 'SUN397' + test_mode = split == 'test' + + super(SUN397, self).__init__( + ann_file=ann_file, + data_root=data_root, + test_mode=test_mode, + data_prefix=data_prefix, + **kwargs) + + def load_data_list(self): + pairs = list_from_file(self.ann_file) + data_list = [] + for pair in pairs: + img_path = self.backend.join_path(self.img_prefix, pair[1:]) + items = pair.split('/') + class_name = '_'.join(items[2:-1]) + gt_label = self.METAINFO['classes'].index(class_name) + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def __getitem__(self, idx: int) -> dict: + try: + return super().__getitem__(idx) + except AttributeError: + raise RuntimeError( + 'Some images in the SUN397 dataset are not a jpg file ' + 'although the name ends with ".jpg". The backend of SUN397 ' + 'should be "pillow" to read these images properly.') + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/textvqa.py b/mmpretrain/datasets/textvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..48a82b45ef1a4cc0bad2ab45b32b8ba8d28b2a60 --- /dev/null +++ b/mmpretrain/datasets/textvqa.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import Counter +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class TextVQA(BaseDataset): + """TextVQA dataset. + + val image: + https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip + test image: + https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip + val json: + https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json + test json: + https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json + + folder structure: + data/textvqa + ├── annotations + │ ├── TextVQA_0.5.1_test.json + │ └── TextVQA_0.5.1_val.json + └── images + ├── test_images + └── train_images + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + question_file (str): Question file path. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file)['data'] + + data_list = [] + + for ann in annotations: + + # ann example + # { + # 'question': 'what is the brand of...is camera?', + # 'image_id': '003a8ae2ef43b901', + # 'image_classes': [ + # 'Cassette deck', 'Printer', ... + # ], + # 'flickr_original_url': 'https://farm2.static...04a6_o.jpg', + # 'flickr_300k_url': 'https://farm2.static...04a6_o.jpg', + # 'image_width': 1024, + # 'image_height': 664, + # 'answers': [ + # 'nous les gosses', + # 'dakota', + # 'clos culombu', + # 'dakota digital' ... + # ], + # 'question_tokens': + # ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'], + # 'question_id': 34602, + # 'set_name': 'val' + # } + + data_info = dict(question=ann['question']) + data_info['question_id'] = ann['question_id'] + data_info['image_id'] = ann['image_id'] + + img_path = mmengine.join_path(self.data_prefix['img_path'], + ann['image_id'] + '.jpg') + data_info['img_path'] = img_path + + data_info['question_id'] = ann['question_id'] + + if 'answers' in ann: + answers = [item for item in ann.pop('answers')] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + data_info['gt_answer'] = list(count.keys()) + data_info['gt_answer_weight'] = answer_weight + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cda99d59c9b147c1842adbffa6bb215f657c33c --- /dev/null +++ b/mmpretrain/datasets/transforms/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.transforms import (CenterCrop, LoadImageFromFile, Normalize, + RandomFlip, RandomGrayscale, RandomResize, Resize) + +from mmpretrain.registry import TRANSFORMS +from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform, + Brightness, ColorTransform, Contrast, Cutout, + Equalize, GaussianBlur, Invert, Posterize, + RandAugment, Rotate, Sharpness, Shear, Solarize, + SolarizeAdd, Translate) +from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs, + PILToNumpy, Transpose) +from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption, + ColorJitter, EfficientNetCenterCrop, + EfficientNetRandomCrop, Lighting, RandomCrop, + RandomErasing, RandomResizedCrop, + RandomResizedCropAndInterpolationWithTwoPic, + RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator) +from .utils import get_transform_idx, remove_transform +from .wrappers import ApplyToList, MultiView + +for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip, + RandomGrayscale, RandomResize, Resize): + TRANSFORMS.register_module(module=t) + +__all__ = [ + 'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop', + 'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert', + 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', + 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', + 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', + 'PackInputs', 'Albumentations', 'EfficientNetRandomCrop', + 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform', + 'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator', + 'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize', + 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView', + 'ApplyToList', 'CleanCaption', 'RandomTranslatePad', + 'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx', + 'remove_transform' +] diff --git a/mmpretrain/datasets/transforms/__pycache__/__init__.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdfacfafafee9f52afaf2f31b9b44995cfb4d14f Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/datasets/transforms/__pycache__/auto_augment.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/auto_augment.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ad2b3fb23e081a2d5a876bc1df9fee7c5ac056a Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/auto_augment.cpython-38.pyc differ diff --git a/mmpretrain/datasets/transforms/__pycache__/formatting.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/formatting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e0621832bc6aa3de7af7d311bf6cfe62db2b264 Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/formatting.cpython-38.pyc differ diff --git a/mmpretrain/datasets/transforms/__pycache__/processing.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/processing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3df804deb0c6b62eeff44e650497d8c98fb3a207 Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/processing.cpython-38.pyc differ diff --git a/mmpretrain/datasets/transforms/__pycache__/utils.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55796fbf9d0df64784b71a70c978c374036a2feb Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpretrain/datasets/transforms/__pycache__/wrappers.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/wrappers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2bd62c3a4cc8ce1471d00aa950c317883b8ef4 Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/wrappers.cpython-38.pyc differ diff --git a/mmpretrain/datasets/transforms/auto_augment.py b/mmpretrain/datasets/transforms/auto_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..03b057b850a4fd797f8f5c0672f60c6c20e44273 --- /dev/null +++ b/mmpretrain/datasets/transforms/auto_augment.py @@ -0,0 +1,1244 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +from copy import deepcopy +from math import ceil +from numbers import Number +from typing import List, Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform, Compose, RandomChoice +from mmcv.transforms.utils import cache_randomness +from mmengine.utils import is_list_of, is_seq_of +from PIL import Image, ImageFilter + +from mmpretrain.registry import TRANSFORMS + + +def merge_hparams(policy: dict, hparams: dict) -> dict: + """Merge hyperparameters into policy config. + + Only merge partial hyperparameters required of the policy. + + Args: + policy (dict): Original policy config dict. + hparams (dict): Hyperparameters need to be merged. + + Returns: + dict: Policy config dict after adding ``hparams``. + """ + policy = deepcopy(policy) + op = TRANSFORMS.get(policy['type']) + assert op is not None, f'Invalid policy type "{policy["type"]}".' + + op_args = inspect.getfullargspec(op.__init__).args + for key, value in hparams.items(): + if key in op_args and key not in policy: + policy[key] = value + return policy + + +@TRANSFORMS.register_module() +class AutoAugment(RandomChoice): + """Auto augmentation. + + This data augmentation is proposed in `AutoAugment: Learning Augmentation + Policies from Data `_. + + Args: + policies (str | list[list[dict]]): The policies of auto augmentation. + If string, use preset policies collection like "imagenet". If list, + Each item is a sub policies, composed by several augmentation + policy dicts. When AutoAugment is called, a random sub policies in + ``policies`` will be selected to augment images. + hparams (dict): Configs of hyperparameters. Hyperparameters will be + used in policies that require these arguments if these arguments + are not set in policy dicts. Defaults to ``dict(pad_val=128)``. + + .. admonition:: Available preset policies + + - ``"imagenet"``: Policy for ImageNet, come from + `DeepVoltaire/AutoAugment`_ + + .. _DeepVoltaire/AutoAugment: https://github.com/DeepVoltaire/AutoAugment + """ + + def __init__(self, + policies: Union[str, List[List[dict]]], + hparams: dict = dict(pad_val=128)): + if isinstance(policies, str): + assert policies in AUTOAUG_POLICIES, 'Invalid policies, ' \ + f'please choose from {list(AUTOAUG_POLICIES.keys())}.' + policies = AUTOAUG_POLICIES[policies] + self.hparams = hparams + self.policies = [[merge_hparams(t, hparams) for t in sub] + for sub in policies] + transforms = [[TRANSFORMS.build(t) for t in sub] for sub in policies] + + super().__init__(transforms=transforms) + + def __repr__(self) -> str: + policies_str = '' + for sub in self.policies: + policies_str += '\n ' + ', \t'.join([t['type'] for t in sub]) + + repr_str = self.__class__.__name__ + repr_str += f'(policies:{policies_str}\n)' + return repr_str + + +@TRANSFORMS.register_module() +class RandAugment(BaseTransform): + r"""Random augmentation. + + This data augmentation is proposed in `RandAugment: Practical automated + data augmentation with a reduced search space + `_. + + Args: + policies (str | list[dict]): The policies of random augmentation. + If string, use preset policies collection like "timm_increasing". + If list, each item is one specific augmentation policy dict. + The policy dict shall should have these keys: + + - ``type`` (str), The type of augmentation. + - ``magnitude_range`` (Sequence[number], optional): For those + augmentation have magnitude, you need to specify the magnitude + level mapping range. For example, assume ``total_level`` is 10, + ``magnitude_level=3`` specify magnitude is 3 if + ``magnitude_range=(0, 10)`` while specify magnitude is 7 if + ``magnitude_range=(10, 0)``. + - other keyword arguments of the augmentation. + + num_policies (int): Number of policies to select from policies each + time. + magnitude_level (int | float): Magnitude level for all the augmentation + selected. + magnitude_std (Number | str): Deviation of magnitude noise applied. + + - If positive number, the magnitude obeys normal distribution + :math:`\mathcal{N}(magnitude_level, magnitude_std)`. + - If 0 or negative number, magnitude remains unchanged. + - If str "inf", the magnitude obeys uniform distribution + :math:`Uniform(min, magnitude)`. + total_level (int | float): Total level for the magnitude. Defaults to + 10. + hparams (dict): Configs of hyperparameters. Hyperparameters will be + used in policies that require these arguments if these arguments + are not set in policy dicts. Defaults to ``dict(pad_val=128)``. + + .. admonition:: Available preset policies + + - ``"timm_increasing"``: The ``_RAND_INCREASING_TRANSFORMS`` policy + from `timm`_ + + .. _timm: https://github.com/rwightman/pytorch-image-models + + Examples: + + To use "timm-increasing" policies collection, select two policies every + time, and magnitude_level of every policy is 6 (total is 10 by default) + + >>> import numpy as np + >>> from mmpretrain.datasets import RandAugment + >>> transform = RandAugment( + ... policies='timm_increasing', + ... num_policies=2, + ... magnitude_level=6, + ... ) + >>> data = {'img': np.random.randint(0, 256, (224, 224, 3))} + >>> results = transform(data) + >>> print(results['img'].shape) + (224, 224, 3) + + If you want the ``magnitude_level`` randomly changes every time, you + can use ``magnitude_std`` to specify the random distribution. For + example, a normal distribution :math:`\mathcal{N}(6, 0.5)`. + + >>> transform = RandAugment( + ... policies='timm_increasing', + ... num_policies=2, + ... magnitude_level=6, + ... magnitude_std=0.5, + ... ) + + You can also use your own policies: + + >>> policies = [ + ... dict(type='AutoContrast'), + ... dict(type='Rotate', magnitude_range=(0, 30)), + ... dict(type='ColorTransform', magnitude_range=(0, 0.9)), + ... ] + >>> transform = RandAugment( + ... policies=policies, + ... num_policies=2, + ... magnitude_level=6 + ... ) + + Note: + ``magnitude_std`` will introduce some randomness to policy, modified by + https://github.com/rwightman/pytorch-image-models. + + When magnitude_std=0, we calculate the magnitude as follows: + + .. math:: + \text{magnitude} = \frac{\text{magnitude_level}} + {\text{totallevel}} \times (\text{val2} - \text{val1}) + + \text{val1} + """ + + def __init__(self, + policies: Union[str, List[dict]], + num_policies: int, + magnitude_level: int, + magnitude_std: Union[Number, str] = 0., + total_level: int = 10, + hparams: dict = dict(pad_val=128)): + if isinstance(policies, str): + assert policies in RANDAUG_POLICIES, 'Invalid policies, ' \ + f'please choose from {list(RANDAUG_POLICIES.keys())}.' + policies = RANDAUG_POLICIES[policies] + + assert is_list_of(policies, dict), 'policies must be a list of dict.' + + assert isinstance(magnitude_std, (Number, str)), \ + '`magnitude_std` must be of number or str type, ' \ + f'got {type(magnitude_std)} instead.' + if isinstance(magnitude_std, str): + assert magnitude_std == 'inf', \ + '`magnitude_std` must be of number or "inf", ' \ + f'got "{magnitude_std}" instead.' + + assert num_policies > 0, 'num_policies must be greater than 0.' + assert magnitude_level >= 0, 'magnitude_level must be no less than 0.' + assert total_level > 0, 'total_level must be greater than 0.' + + self.num_policies = num_policies + self.magnitude_level = magnitude_level + self.magnitude_std = magnitude_std + self.total_level = total_level + self.hparams = hparams + self.policies = [] + self.transforms = [] + + randaug_cfg = dict( + magnitude_level=magnitude_level, + total_level=total_level, + magnitude_std=magnitude_std) + + for policy in policies: + self._check_policy(policy) + policy = merge_hparams(policy, hparams) + policy.pop('magnitude_key', None) # For backward compatibility + if 'magnitude_range' in policy: + policy.update(randaug_cfg) + self.policies.append(policy) + self.transforms.append(TRANSFORMS.build(policy)) + + def __iter__(self): + """Iterate all transforms.""" + return iter(self.transforms) + + def _check_policy(self, policy): + """Check whether the sub-policy dict is available.""" + assert isinstance(policy, dict) and 'type' in policy, \ + 'Each policy must be a dict with key "type".' + type_name = policy['type'] + + if 'magnitude_range' in policy: + magnitude_range = policy['magnitude_range'] + assert is_seq_of(magnitude_range, Number), \ + f'`magnitude_range` of RandAugment policy {type_name} ' \ + 'should be a sequence with two numbers.' + + @cache_randomness + def random_policy_indices(self) -> np.ndarray: + """Return the random chosen transform indices.""" + indices = np.arange(len(self.policies)) + return np.random.choice(indices, size=self.num_policies).tolist() + + def transform(self, results: dict) -> Optional[dict]: + """Randomly choose a sub-policy to apply.""" + + chosen_policies = [ + self.transforms[i] for i in self.random_policy_indices() + ] + + sub_pipeline = Compose(chosen_policies) + return sub_pipeline(results) + + def __repr__(self) -> str: + policies_str = '' + for policy in self.policies: + policies_str += '\n ' + f'{policy["type"]}' + if 'magnitude_range' in policy: + val1, val2 = policy['magnitude_range'] + policies_str += f' ({val1}, {val2})' + + repr_str = self.__class__.__name__ + repr_str += f'(num_policies={self.num_policies}, ' + repr_str += f'magnitude_level={self.magnitude_level}, ' + repr_str += f'total_level={self.total_level}, ' + repr_str += f'policies:{policies_str}\n)' + return repr_str + + +class BaseAugTransform(BaseTransform): + r"""The base class of augmentation transform for RandAugment. + + This class provides several common attributions and methods to support the + magnitude level mapping and magnitude level randomness in + :class:`RandAugment`. + + Args: + magnitude_level (int | float): Magnitude level. + magnitude_range (Sequence[number], optional): For augmentation have + magnitude argument, maybe "magnitude", "angle" or other, you can + specify the magnitude level mapping range to generate the magnitude + argument. For example, assume ``total_level`` is 10, + ``magnitude_level=3`` specify magnitude is 3 if + ``magnitude_range=(0, 10)`` while specify magnitude is 7 if + ``magnitude_range=(10, 0)``. Defaults to None. + magnitude_std (Number | str): Deviation of magnitude noise applied. + + - If positive number, the magnitude obeys normal distribution + :math:`\mathcal{N}(magnitude, magnitude_std)`. + - If 0 or negative number, magnitude remains unchanged. + - If str "inf", the magnitude obeys uniform distribution + :math:`Uniform(min, magnitude)`. + + Defaults to 0. + total_level (int | float): Total level for the magnitude. Defaults to + 10. + prob (float): The probability for performing transformation therefore + should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0. + """ + + def __init__(self, + magnitude_level: int = 10, + magnitude_range: Tuple[float, float] = None, + magnitude_std: Union[str, float] = 0., + total_level: int = 10, + prob: float = 0.5, + random_negative_prob: float = 0.5): + self.magnitude_level = magnitude_level + self.magnitude_range = magnitude_range + self.magnitude_std = magnitude_std + self.total_level = total_level + self.prob = prob + self.random_negative_prob = random_negative_prob + + @cache_randomness + def random_disable(self): + """Randomly disable the transform.""" + return np.random.rand() > self.prob + + @cache_randomness + def random_magnitude(self): + """Randomly generate magnitude.""" + magnitude = self.magnitude_level + # if magnitude_std is positive number or 'inf', move + # magnitude_value randomly. + if self.magnitude_std == 'inf': + magnitude = np.random.uniform(0, magnitude) + elif self.magnitude_std > 0: + magnitude = np.random.normal(magnitude, self.magnitude_std) + magnitude = np.clip(magnitude, 0, self.total_level) + + val1, val2 = self.magnitude_range + magnitude = (magnitude / self.total_level) * (val2 - val1) + val1 + return magnitude + + @cache_randomness + def random_negative(self, value): + """Randomly negative the value.""" + if np.random.rand() < self.random_negative_prob: + return -value + else: + return value + + def extra_repr(self): + """Extra repr string when auto-generating magnitude is enabled.""" + if self.magnitude_range is not None: + repr_str = f', magnitude_level={self.magnitude_level}, ' + repr_str += f'magnitude_range={self.magnitude_range}, ' + repr_str += f'magnitude_std={self.magnitude_std}, ' + repr_str += f'total_level={self.total_level}, ' + return repr_str + else: + return '' + + +@TRANSFORMS.register_module() +class Shear(BaseAugTransform): + """Shear images. + + Args: + magnitude (int | float | None): The magnitude used for shear. If None, + generate from ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing shear therefore should be + in range [0, 1]. Defaults to 0.5. + direction (str): The shearing direction. Options are 'horizontal' and + 'vertical'. Defaults to 'horizontal'. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'bicubic'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + direction: str = 'horizontal', + random_negative_prob: float = 0.5, + interpolation: str = 'bicubic', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + assert direction in ('horizontal', 'vertical'), 'direction must be ' \ + f'either "horizontal" or "vertical", got "{direction}" instead.' + self.direction = direction + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_sheared = mmcv.imshear( + img, + magnitude, + direction=self.direction, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_sheared.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'direction={self.direction}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Translate(BaseAugTransform): + """Translate images. + + Args: + magnitude (int | float | None): The magnitude used for translate. Note + that the offset is calculated by magnitude * size in the + corresponding direction. With a magnitude of 1, the whole image + will be moved out of the range. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing translate therefore should + be in range [0, 1]. Defaults to 0.5. + direction (str): The translating direction. Options are 'horizontal' + and 'vertical'. Defaults to 'horizontal'. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + direction: str = 'horizontal', + random_negative_prob: float = 0.5, + interpolation: str = 'nearest', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + assert direction in ('horizontal', 'vertical'), 'direction must be ' \ + f'either "horizontal" or "vertical", got "{direction}" instead.' + self.direction = direction + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + height, width = img.shape[:2] + if self.direction == 'horizontal': + offset = magnitude * width + else: + offset = magnitude * height + img_translated = mmcv.imtranslate( + img, + offset, + direction=self.direction, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_translated.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'direction={self.direction}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Rotate(BaseAugTransform): + """Rotate images. + + Args: + angle (float, optional): The angle used for rotate. Positive values + stand for clockwise rotation. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If None, the center of the image will be used. + Defaults to None. + scale (float): Isotropic scale factor. Defaults to 1.0. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing rotate therefore should be + in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the angle + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + angle: Optional[float] = None, + center: Optional[Tuple[float]] = None, + scale: float = 1.0, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + random_negative_prob: float = 0.5, + interpolation: str = 'nearest', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (angle is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `angle` and `magnitude_range`.' + + self.angle = angle + self.center = center + self.scale = scale + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.angle is not None: + angle = self.random_negative(self.angle) + else: + angle = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_rotated = mmcv.imrotate( + img, + angle, + center=self.center, + scale=self.scale, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_rotated.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(angle={self.angle}, ' + repr_str += f'center={self.center}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class AutoContrast(BaseAugTransform): + """Auto adjust image contrast. + + Args: + prob (float): The probability for performing auto contrast + therefore should be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_contrasted = mmcv.auto_contrast(img) + results['img'] = img_contrasted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Invert(BaseAugTransform): + """Invert images. + + Args: + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_inverted = mmcv.iminvert(img) + results['img'] = img_inverted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Equalize(BaseAugTransform): + """Equalize the image histogram. + + Args: + prob (float): The probability for performing equalize therefore should + be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_equalized = mmcv.imequalize(img) + results['img'] = img_equalized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Solarize(BaseAugTransform): + """Solarize images (invert all pixel values above a threshold). + + Args: + thr (int | float | None): The threshold above which the pixels value + will be inverted. If None, generate from ``magnitude_range``, + see :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for solarizing therefore should be in + range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + thr: Union[int, float, None] = None, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (thr is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `thr` and `magnitude_range`.' + + self.thr = thr + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.thr is not None: + thr = self.thr + else: + thr = self.random_magnitude() + + img = results['img'] + img_solarized = mmcv.solarize(img, thr=thr) + results['img'] = img_solarized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(thr={self.thr}, ' + repr_str += f'prob={self.prob}{self.extra_repr()}))' + return repr_str + + +@TRANSFORMS.register_module() +class SolarizeAdd(BaseAugTransform): + """SolarizeAdd images (add a certain value to pixels below a threshold). + + Args: + magnitude (int | float | None): The value to be added to pixels below + the thr. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + thr (int | float): The threshold below which the pixels value will be + adjusted. + prob (float): The probability for solarizing therefore should be in + range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + thr: Union[int, float] = 128, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + assert isinstance(thr, (int, float)), 'The thr type must '\ + f'be int or float, but got {type(thr)} instead.' + self.thr = thr + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.magnitude + else: + magnitude = self.random_magnitude() + + img = results['img'] + img_solarized = np.where(img < self.thr, + np.minimum(img + magnitude, 255), img) + results['img'] = img_solarized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'thr={self.thr}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Posterize(BaseAugTransform): + """Posterize images (reduce the number of bits for each color channel). + + Args: + bits (int, optional): Number of bits for each pixel in the output img, + which should be less or equal to 8. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + prob (float): The probability for posterizing therefore should be in + range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + bits: Optional[int] = None, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (bits is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `bits` and `magnitude_range`.' + + if bits is not None: + assert bits <= 8, \ + f'The bits must be less than 8, got {bits} instead.' + self.bits = bits + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.bits is not None: + bits = self.bits + else: + bits = self.random_magnitude() + + # To align timm version, we need to round up to integer here. + bits = ceil(bits) + + img = results['img'] + img_posterized = mmcv.posterize(img, bits=bits) + results['img'] = img_posterized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(bits={self.bits}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Contrast(BaseAugTransform): + """Adjust images contrast. + + Args: + magnitude (int | float | None): The magnitude used for adjusting + contrast. A positive magnitude would enhance the contrast and + a negative magnitude would make the image grayer. A magnitude=0 + gives the origin img. If None, generate from ``magnitude_range``, + see :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing contrast adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_contrasted = mmcv.adjust_contrast(img, factor=1 + magnitude) + results['img'] = img_contrasted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class ColorTransform(BaseAugTransform): + """Adjust images color balance. + + Args: + magnitude (int | float | None): The magnitude used for color transform. + A positive magnitude would enhance the color and a negative + magnitude would make the image grayer. A magnitude=0 gives the + origin img. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing ColorTransform therefore + should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_color_adjusted = mmcv.adjust_color(img, alpha=1 + magnitude) + results['img'] = img_color_adjusted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Brightness(BaseAugTransform): + """Adjust images brightness. + + Args: + magnitude (int | float | None): The magnitude used for adjusting + brightness. A positive magnitude would enhance the brightness and a + negative magnitude would make the image darker. A magnitude=0 gives + the origin img. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing brightness adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_brightened = mmcv.adjust_brightness(img, factor=1 + magnitude) + results['img'] = img_brightened.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Sharpness(BaseAugTransform): + """Adjust images sharpness. + + Args: + magnitude (int | float | None): The magnitude used for adjusting + sharpness. A positive magnitude would enhance the sharpness and a + negative magnitude would make the image bulr. A magnitude=0 gives + the origin img. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing sharpness adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_sharpened = mmcv.adjust_sharpness(img, factor=1 + magnitude) + results['img'] = img_sharpened.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Cutout(BaseAugTransform): + """Cutout images. + + Args: + shape (int | tuple(int) | None): Expected cutout shape (h, w). + If given as a single value, the value will be used for both h and + w. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If it is a sequence, it must have the same length with the image + channels. Defaults to 128. + prob (float): The probability for performing cutout therefore should + be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + shape: Union[int, Tuple[int], None] = None, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (shape is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `shape` and `magnitude_range`.' + + self.shape = shape + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.shape is not None: + shape = self.shape + else: + shape = int(self.random_magnitude()) + + img = results['img'] + img_cutout = mmcv.cutout(img, shape, pad_val=self.pad_val) + results['img'] = img_cutout.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(shape={self.shape}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class GaussianBlur(BaseAugTransform): + """Gaussian blur images. + + Args: + radius (int, float, optional): The blur radius. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + prob (float): The probability for posterizing therefore should be in + range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + radius: Union[int, float, None] = None, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (radius is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `radius` and `magnitude_range`.' + + self.radius = radius + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.radius is not None: + radius = self.radius + else: + radius = self.random_magnitude() + + img = results['img'] + pil_img = Image.fromarray(img) + pil_img.filter(ImageFilter.GaussianBlur(radius=radius)) + results['img'] = np.array(pil_img, dtype=img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(radius={self.radius}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +# yapf: disable +# flake8: noqa +AUTOAUG_POLICIES = { + # Policy for ImageNet, refers to + # https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py + 'imagenet': [ + [dict(type='Posterize', bits=4, prob=0.4), dict(type='Rotate', angle=30., prob=0.6)], + [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)], + [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)], + [dict(type='Posterize', bits=5, prob=0.6), dict(type='Posterize', bits=5, prob=0.6)], + [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)], + [dict(type='Equalize', prob=0.4), dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)], + [dict(type='Solarize', thr=256 / 9 * 6, prob=0.6), dict(type='Equalize', prob=0.6)], + [dict(type='Posterize', bits=6, prob=0.8), dict(type='Equalize', prob=1.)], + [dict(type='Rotate', angle=10., prob=0.2), dict(type='Solarize', thr=256 / 9, prob=0.6)], + [dict(type='Equalize', prob=0.6), dict(type='Posterize', bits=5, prob=0.4)], + [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0., prob=0.4)], + [dict(type='Rotate', angle=30., prob=0.4), dict(type='Equalize', prob=0.6)], + [dict(type='Equalize', prob=0.0), dict(type='Equalize', prob=0.8)], + [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)], + [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0.2, prob=1.)], + [dict(type='ColorTransform', magnitude=0.8, prob=0.8), dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)], + [dict(type='Sharpness', magnitude=0.7, prob=0.4), dict(type='Invert', prob=0.6)], + [dict(type='Shear', magnitude=0.3 / 9 * 5, prob=0.6, direction='horizontal'), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0., prob=0.4), dict(type='Equalize', prob=0.6)], + [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)], + [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)], + [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)], + [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)], + ], +} + +RANDAUG_POLICIES = { + # Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models + 'timm_increasing': [ + dict(type='AutoContrast'), + dict(type='Equalize'), + dict(type='Invert'), + dict(type='Rotate', magnitude_range=(0, 30)), + dict(type='Posterize', magnitude_range=(4, 0)), + dict(type='Solarize', magnitude_range=(256, 0)), + dict(type='SolarizeAdd', magnitude_range=(0, 110)), + dict(type='ColorTransform', magnitude_range=(0, 0.9)), + dict(type='Contrast', magnitude_range=(0, 0.9)), + dict(type='Brightness', magnitude_range=(0, 0.9)), + dict(type='Sharpness', magnitude_range=(0, 0.9)), + dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'), + dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'), + dict(type='Translate', magnitude_range=(0, 0.45), direction='horizontal'), + dict(type='Translate', magnitude_range=(0, 0.45), direction='vertical'), + ], + 'simple_increasing': [ + dict(type='AutoContrast'), + dict(type='Equalize'), + dict(type='Rotate', magnitude_range=(0, 30)), + dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'), + dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'), + ], +} diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d331636a883ce602e419e0867aea7b513b4d87 --- /dev/null +++ b/mmpretrain/datasets/transforms/formatting.py @@ -0,0 +1,353 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from collections.abc import Sequence + +import cv2 +import numpy as np +import torch +import torchvision.transforms.functional as F +from mmcv.transforms import BaseTransform +from mmengine.utils import is_str +from PIL import Image + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample, MultiTaskDataSample + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + """ + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError( + f'Type {type(data)} cannot be converted to tensor.' + 'Supported types are: `numpy.ndarray`, `torch.Tensor`, ' + '`Sequence`, `int` and `float`') + + +@TRANSFORMS.register_module() +class PackInputs(BaseTransform): + """Pack the inputs data. + + **Required Keys:** + + - ``input_key`` + - ``*algorithm_keys`` + - ``*meta_keys`` + + **Deleted Keys:** + + All other keys in the dict. + + **Added Keys:** + + - inputs (:obj:`torch.Tensor`): The forward data of models. + - data_samples (:obj:`~mmpretrain.structures.DataSample`): The + annotation info of the sample. + + Args: + input_key (str): The key of element to feed into the model forwarding. + Defaults to 'img'. + algorithm_keys (Sequence[str]): The keys of custom elements to be used + in the algorithm. Defaults to an empty tuple. + meta_keys (Sequence[str]): The keys of meta information to be saved in + the data sample. Defaults to :attr:`PackInputs.DEFAULT_META_KEYS`. + + .. admonition:: Default algorithm keys + + Besides the specified ``algorithm_keys``, we will set some default keys + into the output data sample and do some formatting. Therefore, you + don't need to set these keys in the ``algorithm_keys``. + + - ``gt_label``: The ground-truth label. The value will be converted + into a 1-D tensor. + - ``gt_score``: The ground-truth score. The value will be converted + into a 1-D tensor. + - ``mask``: The mask for some self-supervise tasks. The value will + be converted into a tensor. + + .. admonition:: Default meta keys + + - ``sample_idx``: The id of the image sample. + - ``img_path``: The path to the image file. + - ``ori_shape``: The original shape of the image as a tuple (H, W). + - ``img_shape``: The shape of the image after the pipeline as a + tuple (H, W). + - ``scale_factor``: The scale factor between the resized image and + the original image. + - ``flip``: A boolean indicating if image flip transform was used. + - ``flip_direction``: The flipping direction. + """ + + DEFAULT_META_KEYS = ('sample_idx', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction') + + def __init__(self, + input_key='img', + algorithm_keys=(), + meta_keys=DEFAULT_META_KEYS): + self.input_key = input_key + self.algorithm_keys = algorithm_keys + self.meta_keys = meta_keys + + @staticmethod + def format_input(input_): + if isinstance(input_, list): + return [PackInputs.format_input(item) for item in input_] + elif isinstance(input_, np.ndarray): + if input_.ndim == 2: # For grayscale image. + input_ = np.expand_dims(input_, -1) + if input_.ndim == 3 and not input_.flags.c_contiguous: + input_ = np.ascontiguousarray(input_.transpose(2, 0, 1)) + input_ = to_tensor(input_) + elif input_.ndim == 3: + # convert to tensor first to accelerate, see + # https://github.com/open-mmlab/mmdetection/pull/9533 + input_ = to_tensor(input_).permute(2, 0, 1).contiguous() + else: + # convert input with other shape to tensor without permute, + # like video input (num_crops, C, T, H, W). + input_ = to_tensor(input_) + elif isinstance(input_, Image.Image): + input_ = F.pil_to_tensor(input_) + elif not isinstance(input_, torch.Tensor): + raise TypeError(f'Unsupported input type {type(input_)}.') + + return input_ + + def transform(self, results: dict) -> dict: + """Method to pack the input data.""" + + packed_results = dict() + if self.input_key in results: + input_ = results[self.input_key] + packed_results['inputs'] = self.format_input(input_) + + data_sample = DataSample() + + # Set default keys + if 'gt_label' in results: + data_sample.set_gt_label(results['gt_label']) + if 'gt_score' in results: + data_sample.set_gt_score(results['gt_score']) + if 'mask' in results: + data_sample.set_mask(results['mask']) + + # Set custom algorithm keys + for key in self.algorithm_keys: + if key in results: + data_sample.set_field(results[key], key) + + # Set meta keys + for key in self.meta_keys: + if key in results: + data_sample.set_field(results[key], key, field_type='metainfo') + + packed_results['data_samples'] = data_sample + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(input_key='{self.input_key}', " + repr_str += f'algorithm_keys={self.algorithm_keys}, ' + repr_str += f'meta_keys={self.meta_keys})' + return repr_str + + +@TRANSFORMS.register_module() +class PackMultiTaskInputs(BaseTransform): + """Convert all image labels of multi-task dataset to a dict of tensor. + + Args: + multi_task_fields (Sequence[str]): + input_key (str): + task_handlers (dict): + """ + + def __init__(self, + multi_task_fields, + input_key='img', + task_handlers=dict()): + self.multi_task_fields = multi_task_fields + self.input_key = input_key + self.task_handlers = defaultdict(PackInputs) + for task_name, task_handler in task_handlers.items(): + self.task_handlers[task_name] = TRANSFORMS.build(task_handler) + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + + result = {'img_path': 'a.png', 'gt_label': {'task1': 1, 'task3': 3}, + 'img': array([[[ 0, 0, 0]) + """ + packed_results = dict() + results = results.copy() + + if self.input_key in results: + input_ = results[self.input_key] + packed_results['inputs'] = PackInputs.format_input(input_) + + task_results = defaultdict(dict) + for field in self.multi_task_fields: + if field in results: + value = results.pop(field) + for k, v in value.items(): + task_results[k].update({field: v}) + + data_sample = MultiTaskDataSample() + for task_name, task_result in task_results.items(): + task_handler = self.task_handlers[task_name] + task_pack_result = task_handler({**results, **task_result}) + data_sample.set_field(task_pack_result['data_samples'], task_name) + + packed_results['data_samples'] = data_sample + return packed_results + + def __repr__(self): + repr = self.__class__.__name__ + task_handlers = ', '.join( + f"'{name}': {handler.__class__.__name__}" + for name, handler in self.task_handlers.items()) + repr += f'(multi_task_fields={self.multi_task_fields}, ' + repr += f"input_key='{self.input_key}', " + repr += f'task_handlers={{{task_handlers}}})' + return repr + + +@TRANSFORMS.register_module() +class Transpose(BaseTransform): + """Transpose numpy array. + + **Required Keys:** + + - ``*keys`` + + **Modified Keys:** + + - ``*keys`` + + Args: + keys (List[str]): The fields to convert to tensor. + order (List[int]): The output dimensions order. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def transform(self, results): + """Method to transpose array.""" + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL')) +class NumpyToPIL(BaseTransform): + """Convert the image from OpenCV format to :obj:`PIL.Image.Image`. + + **Required Keys:** + + - ``img`` + + **Modified Keys:** + + - ``img`` + + Args: + to_rgb (bool): Whether to convert img to rgb. Defaults to True. + """ + + def __init__(self, to_rgb: bool = False) -> None: + self.to_rgb = to_rgb + + def transform(self, results: dict) -> dict: + """Method to convert images to :obj:`PIL.Image.Image`.""" + img = results['img'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img + + results['img'] = Image.fromarray(img) + return results + + def __repr__(self) -> str: + return self.__class__.__name__ + f'(to_rgb={self.to_rgb})' + + +@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy')) +class PILToNumpy(BaseTransform): + """Convert img to :obj:`numpy.ndarray`. + + **Required Keys:** + + - ``img`` + + **Modified Keys:** + + - ``img`` + + Args: + to_bgr (bool): Whether to convert img to rgb. Defaults to True. + dtype (str, optional): The dtype of the converted numpy array. + Defaults to None. + """ + + def __init__(self, to_bgr: bool = False, dtype=None) -> None: + self.to_bgr = to_bgr + self.dtype = dtype + + def transform(self, results: dict) -> dict: + """Method to convert img to :obj:`numpy.ndarray`.""" + img = np.array(results['img'], dtype=self.dtype) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img + + results['img'] = img + return results + + def __repr__(self) -> str: + return self.__class__.__name__ + \ + f'(to_bgr={self.to_bgr}, dtype={self.dtype})' + + +@TRANSFORMS.register_module() +class Collect(BaseTransform): + """Collect and only reserve the specified fields. + + **Required Keys:** + + - ``*keys`` + + **Deleted Keys:** + + All keys except those in the argument ``*keys``. + + Args: + keys (Sequence[str]): The keys of the fields to be collected. + """ + + def __init__(self, keys): + self.keys = keys + + def transform(self, results): + data = {} + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc04da74ec8032690f1a59126e09a323e9a0036 --- /dev/null +++ b/mmpretrain/datasets/transforms/processing.py @@ -0,0 +1,1742 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import math +import numbers +import re +import string +from enum import EnumMeta +from numbers import Number +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import mmcv +import mmengine +import numpy as np +import torchvision +from mmcv.transforms import BaseTransform +from mmcv.transforms.utils import cache_randomness +from torchvision.transforms.transforms import InterpolationMode + +from mmpretrain.registry import TRANSFORMS + +try: + import albumentations +except ImportError: + albumentations = None + + +def _str_to_torch_dtype(t: str): + """mapping str format dtype to torch.dtype.""" + import torch # noqa: F401,F403 + return eval(f'torch.{t}') + + +def _interpolation_modes_from_str(t: str): + """mapping str format to Interpolation.""" + t = t.lower() + inverse_modes_mapping = { + 'nearest': InterpolationMode.NEAREST, + 'bilinear': InterpolationMode.BILINEAR, + 'bicubic': InterpolationMode.BICUBIC, + 'box': InterpolationMode.BOX, + 'hammimg': InterpolationMode.HAMMING, + 'lanczos': InterpolationMode.LANCZOS, + } + return inverse_modes_mapping[t] + + +class TorchVisonTransformWrapper: + + def __init__(self, transform, *args, **kwargs): + if 'interpolation' in kwargs and isinstance(kwargs['interpolation'], + str): + kwargs['interpolation'] = _interpolation_modes_from_str( + kwargs['interpolation']) + if 'dtype' in kwargs and isinstance(kwargs['dtype'], str): + kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype']) + self.t = transform(*args, **kwargs) + + def __call__(self, results): + results['img'] = self.t(results['img']) + return results + + def __repr__(self) -> str: + return f'TorchVision{repr(self.t)}' + + +def register_vision_transforms() -> List[str]: + """Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS`` + registry. + + Returns: + List[str]: A list of registered transforms' name. + """ + vision_transforms = [] + for module_name in dir(torchvision.transforms): + if not re.match('[A-Z]', module_name): + # must startswith a capital letter + continue + _transform = getattr(torchvision.transforms, module_name) + if inspect.isclass(_transform) and callable( + _transform) and not isinstance(_transform, (EnumMeta)): + from functools import partial + TRANSFORMS.register_module( + module=partial( + TorchVisonTransformWrapper, transform=_transform), + name=f'torchvision/{module_name}') + vision_transforms.append(f'torchvision/{module_name}') + return vision_transforms + + +# register all the transforms in torchvision by using a transform wrapper +VISION_TRANSFORMS = register_vision_transforms() + + +@TRANSFORMS.register_module() +class RandomCrop(BaseTransform): + """Crop the given Image at a random location. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + crop_size (int | Sequence): Desired output size of the crop. If + crop_size is an int instead of sequence like (h, w), a square crop + (crop_size, crop_size) is made. + padding (int | Sequence, optional): Optional padding on each border + of the image. If a sequence of length 4 is provided, it is used to + pad left, top, right, bottom borders respectively. If a sequence + of length 2 is provided, it is used to pad left/right, top/bottom + borders, respectively. Default: None, which means no padding. + pad_if_needed (bool): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + Default: False. + pad_val (Number | Sequence[Number]): Pixel pad_val value for constant + fill. If a tuple of length 3, it is used to pad_val R, G, B + channels respectively. Default: 0. + padding_mode (str): Type of padding. Defaults to "constant". Should + be one of the following: + + - ``constant``: Pads with a constant value, this value is specified + with pad_val. + - ``edge``: pads with the last value at the edge of the image. + - ``reflect``: Pads with reflection of image without repeating the + last value on the edge. For example, padding [1, 2, 3, 4] + with 2 elements on both sides in reflect mode will result + in [3, 2, 1, 2, 3, 4, 3, 2]. + - ``symmetric``: Pads with reflection of image repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with + 2 elements on both sides in symmetric mode will result in + [2, 1, 1, 2, 3, 4, 4, 3]. + """ + + def __init__(self, + crop_size: Union[Sequence, int], + padding: Optional[Union[Sequence, int]] = None, + pad_if_needed: bool = False, + pad_val: Union[Number, Sequence[Number]] = 0, + padding_mode: str = 'constant'): + if isinstance(crop_size, Sequence): + assert len(crop_size) == 2 + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + else: + assert crop_size > 0 + self.crop_size = (crop_size, crop_size) + # check padding mode + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + self.padding = padding + self.pad_if_needed = pad_if_needed + self.pad_val = pad_val + self.padding_mode = padding_mode + + @cache_randomness + def rand_crop_params(self, img: np.ndarray): + """Get parameters for ``crop`` for a random crop. + + Args: + img (ndarray): Image to be cropped. + + Returns: + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to ``crop`` for random crop. + """ + h, w = img.shape[:2] + target_h, target_w = self.crop_size + if w == target_w and h == target_h: + return 0, 0, h, w + elif w < target_w or h < target_h: + target_w = min(w, target_w) + target_h = min(w, target_h) + + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) + + return offset_h, offset_w, target_h, target_w + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. + """ + img = results['img'] + if self.padding is not None: + img = mmcv.impad(img, padding=self.padding, pad_val=self.pad_val) + + # pad img if needed + if self.pad_if_needed: + h_pad = math.ceil(max(0, self.crop_size[0] - img.shape[0]) / 2) + w_pad = math.ceil(max(0, self.crop_size[1] - img.shape[1]) / 2) + + img = mmcv.impad( + img, + padding=(w_pad, h_pad, w_pad, h_pad), + pad_val=self.pad_val, + padding_mode=self.padding_mode) + + offset_h, offset_w, target_h, target_w = self.rand_crop_params(img) + img = mmcv.imcrop( + img, + np.array([ + offset_w, + offset_h, + offset_w + target_w - 1, + offset_h + target_h - 1, + ])) + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}' + repr_str += f', padding={self.padding}' + repr_str += f', pad_if_needed={self.pad_if_needed}' + repr_str += f', pad_val={self.pad_val}' + repr_str += f', padding_mode={self.padding_mode})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomResizedCrop(BaseTransform): + """Crop the given image to random scale and aspect ratio. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a + random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio + is made. This crop is finally resized to given size. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + scale (sequence | int): Desired output scale of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + crop_ratio_range (tuple): Range of the random size of the cropped + image compared to the original image. Defaults to (0.08, 1.0). + aspect_ratio_range (tuple): Range of the random aspect ratio of the + cropped image compared to the original image. + Defaults to (3. / 4., 4. / 3.). + max_attempts (int): Maximum number of attempts before falling back to + Central Crop. Defaults to 10. + interpolation (str): Interpolation method, accepted values are + 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to + 'bilinear'. + backend (str): The image resize backend type, accepted values are + 'cv2' and 'pillow'. Defaults to 'cv2'. + """ + + def __init__(self, + scale: Union[Sequence, int], + crop_ratio_range: Tuple[float, float] = (0.08, 1.0), + aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.), + max_attempts: int = 10, + interpolation: str = 'bilinear', + backend: str = 'cv2') -> None: + if isinstance(scale, Sequence): + assert len(scale) == 2 + assert scale[0] > 0 and scale[1] > 0 + self.scale = scale + else: + assert scale > 0 + self.scale = (scale, scale) + if (crop_ratio_range[0] > crop_ratio_range[1]) or ( + aspect_ratio_range[0] > aspect_ratio_range[1]): + raise ValueError( + 'range should be of kind (min, max). ' + f'But received crop_ratio_range {crop_ratio_range} ' + f'and aspect_ratio_range {aspect_ratio_range}.') + assert isinstance(max_attempts, int) and max_attempts >= 0, \ + 'max_attempts mush be int and no less than 0.' + assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area', + 'lanczos') + + self.crop_ratio_range = crop_ratio_range + self.aspect_ratio_range = aspect_ratio_range + self.max_attempts = max_attempts + self.interpolation = interpolation + self.backend = backend + + @cache_randomness + def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (ndarray): Image to be cropped. + + Returns: + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to `crop` for a random sized crop. + """ + h, w = img.shape[:2] + area = h * w + + for _ in range(self.max_attempts): + target_area = np.random.uniform(*self.crop_ratio_range) * area + log_ratio = (math.log(self.aspect_ratio_range[0]), + math.log(self.aspect_ratio_range[1])) + aspect_ratio = math.exp(np.random.uniform(*log_ratio)) + target_w = int(round(math.sqrt(target_area * aspect_ratio))) + target_h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < target_w <= w and 0 < target_h <= h: + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) + + return offset_h, offset_w, target_h, target_w + + # Fallback to central crop + in_ratio = float(w) / float(h) + if in_ratio < min(self.aspect_ratio_range): + target_w = w + target_h = int(round(target_w / min(self.aspect_ratio_range))) + elif in_ratio > max(self.aspect_ratio_range): + target_h = h + target_w = int(round(target_h * max(self.aspect_ratio_range))) + else: # whole image + target_w = w + target_h = h + offset_h = (h - target_h) // 2 + offset_w = (w - target_w) // 2 + return offset_h, offset_w, target_h, target_w + + def transform(self, results: dict) -> dict: + """Transform function to randomly resized crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly resized cropped results, 'img_shape' + key in result dict is updated according to crop size. + """ + img = results['img'] + offset_h, offset_w, target_h, target_w = self.rand_crop_params(img) + img = mmcv.imcrop( + img, + bboxes=np.array([ + offset_w, offset_h, offset_w + target_w - 1, + offset_h + target_h - 1 + ])) + img = mmcv.imresize( + img, + tuple(self.scale[::-1]), + interpolation=self.interpolation, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + f'(scale={self.scale}' + repr_str += ', crop_ratio_range=' + repr_str += f'{tuple(round(s, 4) for s in self.crop_ratio_range)}' + repr_str += ', aspect_ratio_range=' + repr_str += f'{tuple(round(r, 4) for r in self.aspect_ratio_range)}' + repr_str += f', max_attempts={self.max_attempts}' + repr_str += f', interpolation={self.interpolation}' + repr_str += f', backend={self.backend})' + return repr_str + + +@TRANSFORMS.register_module() +class EfficientNetRandomCrop(RandomResizedCrop): + """EfficientNet style RandomResizedCrop. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + scale (int): Desired output scale of the crop. Only int size is + accepted, a square crop (size, size) is made. + min_covered (Number): Minimum ratio of the cropped area to the original + area. Defaults to 0.1. + crop_padding (int): The crop padding parameter in efficientnet style + center crop. Defaults to 32. + crop_ratio_range (tuple): Range of the random size of the cropped + image compared to the original image. Defaults to (0.08, 1.0). + aspect_ratio_range (tuple): Range of the random aspect ratio of the + cropped image compared to the original image. + Defaults to (3. / 4., 4. / 3.). + max_attempts (int): Maximum number of attempts before falling back to + Central Crop. Defaults to 10. + interpolation (str): Interpolation method, accepted values are + 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to + 'bicubic'. + backend (str): The image resize backend type, accepted values are + 'cv2' and 'pillow'. Defaults to 'cv2'. + """ + + def __init__(self, + scale: int, + min_covered: float = 0.1, + crop_padding: int = 32, + interpolation: str = 'bicubic', + **kwarg): + assert isinstance(scale, int) + super().__init__(scale, interpolation=interpolation, **kwarg) + assert min_covered >= 0, 'min_covered should be no less than 0.' + assert crop_padding >= 0, 'crop_padding should be no less than 0.' + + self.min_covered = min_covered + self.crop_padding = crop_padding + + # https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py # noqa + @cache_randomness + def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (ndarray): Image to be cropped. + + Returns: + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to `crop` for a random sized crop. + """ + h, w = img.shape[:2] + area = h * w + min_target_area = self.crop_ratio_range[0] * area + max_target_area = self.crop_ratio_range[1] * area + + for _ in range(self.max_attempts): + aspect_ratio = np.random.uniform(*self.aspect_ratio_range) + min_target_h = int( + round(math.sqrt(min_target_area / aspect_ratio))) + max_target_h = int( + round(math.sqrt(max_target_area / aspect_ratio))) + + if max_target_h * aspect_ratio > w: + max_target_h = int((w + 0.5 - 1e-7) / aspect_ratio) + if max_target_h * aspect_ratio > w: + max_target_h -= 1 + + max_target_h = min(max_target_h, h) + min_target_h = min(max_target_h, min_target_h) + + # slightly differs from tf implementation + target_h = int( + round(np.random.uniform(min_target_h, max_target_h))) + target_w = int(round(target_h * aspect_ratio)) + target_area = target_h * target_w + + # slight differs from tf. In tf, if target_area > max_target_area, + # area will be recalculated + if (target_area < min_target_area or target_area > max_target_area + or target_w > w or target_h > h + or target_area < self.min_covered * area): + continue + + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) + + return offset_h, offset_w, target_h, target_w + + # Fallback to central crop + img_short = min(h, w) + crop_size = self.scale[0] / (self.scale[0] + + self.crop_padding) * img_short + + offset_h = max(0, int(round((h - crop_size) / 2.))) + offset_w = max(0, int(round((w - crop_size) / 2.))) + return offset_h, offset_w, crop_size, crop_size + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = super().__repr__()[:-1] + repr_str += f', min_covered={self.min_covered}' + repr_str += f', crop_padding={self.crop_padding})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomErasing(BaseTransform): + """Randomly selects a rectangle region in an image and erase pixels. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + + Args: + erase_prob (float): Probability that image will be randomly erased. + Default: 0.5 + min_area_ratio (float): Minimum erased area / input image area + Default: 0.02 + max_area_ratio (float): Maximum erased area / input image area + Default: 0.4 + aspect_range (sequence | float): Aspect ratio range of erased area. + if float, it will be converted to (aspect_ratio, 1/aspect_ratio) + Default: (3/10, 10/3) + mode (str): Fill method in erased area, can be: + + - const (default): All pixels are assign with the same value. + - rand: each pixel is assigned with a random value in [0, 255] + + fill_color (sequence | Number): Base color filled in erased area. + Defaults to (128, 128, 128). + fill_std (sequence | Number, optional): If set and ``mode`` is 'rand', + fill erased area with random color from normal distribution + (mean=fill_color, std=fill_std); If not set, fill erased area with + random color from uniform distribution (0~255). Defaults to None. + + Note: + See `Random Erasing Data Augmentation + `_ + + This paper provided 4 modes: RE-R, RE-M, RE-0, RE-255, and use RE-M as + default. The config of these 4 modes are: + + - RE-R: RandomErasing(mode='rand') + - RE-M: RandomErasing(mode='const', fill_color=(123.67, 116.3, 103.5)) + - RE-0: RandomErasing(mode='const', fill_color=0) + - RE-255: RandomErasing(mode='const', fill_color=255) + """ + + def __init__(self, + erase_prob=0.5, + min_area_ratio=0.02, + max_area_ratio=0.4, + aspect_range=(3 / 10, 10 / 3), + mode='const', + fill_color=(128, 128, 128), + fill_std=None): + assert isinstance(erase_prob, float) and 0. <= erase_prob <= 1. + assert isinstance(min_area_ratio, float) and 0. <= min_area_ratio <= 1. + assert isinstance(max_area_ratio, float) and 0. <= max_area_ratio <= 1. + assert min_area_ratio <= max_area_ratio, \ + 'min_area_ratio should be smaller than max_area_ratio' + if isinstance(aspect_range, float): + aspect_range = min(aspect_range, 1 / aspect_range) + aspect_range = (aspect_range, 1 / aspect_range) + assert isinstance(aspect_range, Sequence) and len(aspect_range) == 2 \ + and all(isinstance(x, float) for x in aspect_range), \ + 'aspect_range should be a float or Sequence with two float.' + assert all(x > 0 for x in aspect_range), \ + 'aspect_range should be positive.' + assert aspect_range[0] <= aspect_range[1], \ + 'In aspect_range (min, max), min should be smaller than max.' + assert mode in ['const', 'rand'], \ + 'Please select `mode` from ["const", "rand"].' + if isinstance(fill_color, Number): + fill_color = [fill_color] * 3 + assert isinstance(fill_color, Sequence) and len(fill_color) == 3 \ + and all(isinstance(x, Number) for x in fill_color), \ + 'fill_color should be a float or Sequence with three int.' + if fill_std is not None: + if isinstance(fill_std, Number): + fill_std = [fill_std] * 3 + assert isinstance(fill_std, Sequence) and len(fill_std) == 3 \ + and all(isinstance(x, Number) for x in fill_std), \ + 'fill_std should be a float or Sequence with three int.' + + self.erase_prob = erase_prob + self.min_area_ratio = min_area_ratio + self.max_area_ratio = max_area_ratio + self.aspect_range = aspect_range + self.mode = mode + self.fill_color = fill_color + self.fill_std = fill_std + + def _fill_pixels(self, img, top, left, h, w): + """Fill pixels to the patch of image.""" + if self.mode == 'const': + patch = np.empty((h, w, 3), dtype=np.uint8) + patch[:, :] = np.array(self.fill_color, dtype=np.uint8) + elif self.fill_std is None: + # Uniform distribution + patch = np.random.uniform(0, 256, (h, w, 3)).astype(np.uint8) + else: + # Normal distribution + patch = np.random.normal(self.fill_color, self.fill_std, (h, w, 3)) + patch = np.clip(patch.astype(np.int32), 0, 255).astype(np.uint8) + + img[top:top + h, left:left + w] = patch + return img + + @cache_randomness + def random_disable(self): + """Randomly disable the transform.""" + return np.random.rand() > self.erase_prob + + @cache_randomness + def random_patch(self, img_h, img_w): + """Randomly generate patch the erase.""" + # convert the aspect ratio to log space to equally handle width and + # height. + log_aspect_range = np.log( + np.array(self.aspect_range, dtype=np.float32)) + aspect_ratio = np.exp(np.random.uniform(*log_aspect_range)) + area = img_h * img_w + area *= np.random.uniform(self.min_area_ratio, self.max_area_ratio) + + h = min(int(round(np.sqrt(area * aspect_ratio))), img_h) + w = min(int(round(np.sqrt(area / aspect_ratio))), img_w) + top = np.random.randint(0, img_h - h) if img_h > h else 0 + left = np.random.randint(0, img_w - w) if img_w > w else 0 + return top, left, h, w + + def transform(self, results): + """ + Args: + results (dict): Results dict from pipeline + + Returns: + dict: Results after the transformation. + """ + if self.random_disable(): + return results + + img = results['img'] + img_h, img_w = img.shape[:2] + + img = self._fill_pixels(img, *self.random_patch(img_h, img_w)) + + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(erase_prob={self.erase_prob}, ' + repr_str += f'min_area_ratio={self.min_area_ratio}, ' + repr_str += f'max_area_ratio={self.max_area_ratio}, ' + repr_str += f'aspect_range={self.aspect_range}, ' + repr_str += f'mode={self.mode}, ' + repr_str += f'fill_color={self.fill_color}, ' + repr_str += f'fill_std={self.fill_std})' + return repr_str + + +@TRANSFORMS.register_module() +class EfficientNetCenterCrop(BaseTransform): + r"""EfficientNet style center crop. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + crop_size (int): Expected size after cropping with the format + of (h, w). + crop_padding (int): The crop padding parameter in efficientnet style + center crop. Defaults to 32. + interpolation (str): Interpolation method, accepted values are + 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if + ``efficientnet_style`` is True. Defaults to 'bicubic'. + backend (str): The image resize backend type, accepted values are + `cv2` and `pillow`. Only valid if efficientnet style is True. + Defaults to `cv2`. + Notes: + - If the image is smaller than the crop size, return the original + image. + - The pipeline will be to first + to perform the center crop with the ``crop_size_`` as: + + .. math:: + + \text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} + + \text{crop_padding}} \times \text{short_edge} + + And then the pipeline resizes the img to the input crop size. + """ + + def __init__(self, + crop_size: int, + crop_padding: int = 32, + interpolation: str = 'bicubic', + backend: str = 'cv2'): + assert isinstance(crop_size, int) + assert crop_size > 0 + assert crop_padding >= 0 + assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area', + 'lanczos') + + self.crop_size = crop_size + self.crop_padding = crop_padding + self.interpolation = interpolation + self.backend = backend + + def transform(self, results: dict) -> dict: + """Transform function to randomly resized crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: EfficientNet style center cropped results, 'img_shape' + key in result dict is updated according to crop size. + """ + img = results['img'] + h, w = img.shape[:2] + + # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L118 # noqa + img_short = min(h, w) + crop_size = self.crop_size / (self.crop_size + + self.crop_padding) * img_short + + offset_h = max(0, int(round((h - crop_size) / 2.))) + offset_w = max(0, int(round((w - crop_size) / 2.))) + + # crop the image + img = mmcv.imcrop( + img, + bboxes=np.array([ + offset_w, offset_h, offset_w + crop_size - 1, + offset_h + crop_size - 1 + ])) + # resize image + img = mmcv.imresize( + img, (self.crop_size, self.crop_size), + interpolation=self.interpolation, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}' + repr_str += f', crop_padding={self.crop_padding}' + repr_str += f', interpolation={self.interpolation}' + repr_str += f', backend={self.backend})' + return repr_str + + +@TRANSFORMS.register_module() +class ResizeEdge(BaseTransform): + """Resize images along the specified edge. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + **Added Keys:** + + - scale + - scale_factor + + Args: + scale (int): The edge scale to resizing. + edge (str): The edge to resize. Defaults to 'short'. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. + Defaults to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + Defaults to 'bilinear'. + """ + + def __init__(self, + scale: int, + edge: str = 'short', + backend: str = 'cv2', + interpolation: str = 'bilinear') -> None: + allow_edges = ['short', 'long', 'width', 'height'] + assert edge in allow_edges, \ + f'Invalid edge "{edge}", please specify from {allow_edges}.' + self.edge = edge + self.scale = scale + self.backend = backend + self.interpolation = interpolation + + def _resize_img(self, results: dict) -> None: + """Resize images with ``results['scale']``.""" + + img, w_scale, h_scale = mmcv.imresize( + results['img'], + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape[:2] + results['scale'] = img.shape[:2][::-1] + results['scale_factor'] = (w_scale, h_scale) + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img', 'scale', 'scale_factor', + 'img_shape' keys are updated in result dict. + """ + assert 'img' in results, 'No `img` field in the input.' + + h, w = results['img'].shape[:2] + if any([ + # conditions to resize the width + self.edge == 'short' and w < h, + self.edge == 'long' and w > h, + self.edge == 'width', + ]): + width = self.scale + height = int(self.scale * h / w) + else: + height = self.scale + width = int(self.scale * w / h) + results['scale'] = (width, height) + + self._resize_img(results) + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'edge={self.edge}, ' + repr_str += f'backend={self.backend}, ' + repr_str += f'interpolation={self.interpolation})' + return repr_str + + +@TRANSFORMS.register_module() +class ColorJitter(BaseTransform): + """Randomly change the brightness, contrast and saturation of an image. + + Modified from + https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py + Licensed under the BSD 3-Clause License. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + + Args: + brightness (float | Sequence[float] (min, max)): How much to jitter + brightness. brightness_factor is chosen uniformly from + ``[max(0, 1 - brightness), 1 + brightness]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + contrast (float | Sequence[float] (min, max)): How much to jitter + contrast. contrast_factor is chosen uniformly from + ``[max(0, 1 - contrast), 1 + contrast]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + saturation (float | Sequence[float] (min, max)): How much to jitter + saturation. saturation_factor is chosen uniformly from + ``[max(0, 1 - saturation), 1 + saturation]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + hue (float | Sequence[float] (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from ``[-hue, hue]`` (0 <= hue + <= 0.5) or the given ``[min, max]`` (-0.5 <= min <= max <= 0.5). + Defaults to 0. + backend (str): The backend to operate the image. Defaults to 'pillow' + """ + + def __init__(self, + brightness: Union[float, Sequence[float]] = 0., + contrast: Union[float, Sequence[float]] = 0., + saturation: Union[float, Sequence[float]] = 0., + hue: Union[float, Sequence[float]] = 0., + backend='pillow'): + self.brightness = self._set_range(brightness, 'brightness') + self.contrast = self._set_range(contrast, 'contrast') + self.saturation = self._set_range(saturation, 'saturation') + self.hue = self._set_range(hue, 'hue', center=0, bound=(-0.5, 0.5)) + self.backend = backend + + def _set_range(self, value, name, center=1, bound=(0, float('inf'))): + """Set the range of magnitudes.""" + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + f'If {name} is a single number, it must be non negative.') + value = (center - float(value), center + float(value)) + + if isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + value = np.clip(value, bound[0], bound[1]) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.warning(f'ColorJitter {name} values exceed the bound ' + f'{bound}, clipped to the bound.') + else: + raise TypeError(f'{name} should be a single number ' + 'or a list/tuple with length 2.') + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + else: + value = tuple(value) + + return value + + @cache_randomness + def _rand_params(self): + """Get random parameters including magnitudes and indices of + transforms.""" + trans_inds = np.random.permutation(4) + b, c, s, h = (None, ) * 4 + + if self.brightness is not None: + b = np.random.uniform(self.brightness[0], self.brightness[1]) + if self.contrast is not None: + c = np.random.uniform(self.contrast[0], self.contrast[1]) + if self.saturation is not None: + s = np.random.uniform(self.saturation[0], self.saturation[1]) + if self.hue is not None: + h = np.random.uniform(self.hue[0], self.hue[1]) + + return trans_inds, b, c, s, h + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: ColorJitter results, 'img' key is updated in result dict. + """ + img = results['img'] + trans_inds, brightness, contrast, saturation, hue = self._rand_params() + + for index in trans_inds: + if index == 0 and brightness is not None: + img = mmcv.adjust_brightness( + img, brightness, backend=self.backend) + elif index == 1 and contrast is not None: + img = mmcv.adjust_contrast(img, contrast, backend=self.backend) + elif index == 2 and saturation is not None: + img = mmcv.adjust_color( + img, alpha=saturation, backend=self.backend) + elif index == 3 and hue is not None: + img = mmcv.adjust_hue(img, hue, backend=self.backend) + + results['img'] = img + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(brightness={self.brightness}, ' + repr_str += f'contrast={self.contrast}, ' + repr_str += f'saturation={self.saturation}, ' + repr_str += f'hue={self.hue})' + return repr_str + + +@TRANSFORMS.register_module() +class Lighting(BaseTransform): + """Adjust images lighting using AlexNet-style PCA jitter. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + + Args: + eigval (Sequence[float]): the eigenvalue of the convariance matrix + of pixel values, respectively. + eigvec (list[list]): the eigenvector of the convariance matrix of + pixel values, respectively. + alphastd (float): The standard deviation for distribution of alpha. + Defaults to 0.1. + to_rgb (bool): Whether to convert img to rgb. Defaults to False. + """ + + def __init__(self, + eigval: Sequence[float], + eigvec: Sequence[float], + alphastd: float = 0.1, + to_rgb: bool = False): + assert isinstance(eigval, Sequence), \ + f'eigval must be Sequence, got {type(eigval)} instead.' + assert isinstance(eigvec, Sequence), \ + f'eigvec must be Sequence, got {type(eigvec)} instead.' + for vec in eigvec: + assert isinstance(vec, Sequence) and len(vec) == len(eigvec[0]), \ + 'eigvec must contains lists with equal length.' + assert isinstance(alphastd, float), 'alphastd should be of type ' \ + f'float or int, got {type(alphastd)} instead.' + + self.eigval = np.array(eigval) + self.eigvec = np.array(eigvec) + self.alphastd = alphastd + self.to_rgb = to_rgb + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Lightinged results, 'img' key is updated in result dict. + """ + assert 'img' in results, 'No `img` field in the input.' + + img = results['img'] + img_lighting = mmcv.adjust_lighting( + img, + self.eigval, + self.eigvec, + alphastd=self.alphastd, + to_rgb=self.to_rgb) + results['img'] = img_lighting.astype(img.dtype) + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(eigval={self.eigval.tolist()}, ' + repr_str += f'eigvec={self.eigvec.tolist()}, ' + repr_str += f'alphastd={self.alphastd}, ' + repr_str += f'to_rgb={self.to_rgb})' + return repr_str + + +# 'Albu' is used in previous versions of mmpretrain, here is for compatibility +# users can use both 'Albumentations' and 'Albu'. +@TRANSFORMS.register_module(['Albumentations', 'Albu']) +class Albumentations(BaseTransform): + """Wrapper to use augmentation from albumentations library. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Adds custom transformations from albumentations library. + More details can be found in + `Albumentations `_. + An example of ``transforms`` is as followed: + + .. code-block:: + + [ + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=0.5), + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + + Args: + transforms (List[Dict]): List of albumentations transform configs. + keymap (Optional[Dict]): Mapping of mmpretrain to albumentations + fields, in format {'input key':'albumentation-style key'}. + Defaults to None. + + Example: + >>> import mmcv + >>> from mmpretrain.datasets import Albumentations + >>> transforms = [ + ... dict( + ... type='ShiftScaleRotate', + ... shift_limit=0.0625, + ... scale_limit=0.0, + ... rotate_limit=0, + ... interpolation=1, + ... p=0.5), + ... dict( + ... type='RandomBrightnessContrast', + ... brightness_limit=[0.1, 0.3], + ... contrast_limit=[0.1, 0.3], + ... p=0.2), + ... dict(type='ChannelShuffle', p=0.1), + ... dict( + ... type='OneOf', + ... transforms=[ + ... dict(type='Blur', blur_limit=3, p=1.0), + ... dict(type='MedianBlur', blur_limit=3, p=1.0) + ... ], + ... p=0.1), + ... ] + >>> albu = Albumentations(transforms) + >>> data = {'img': mmcv.imread('./demo/demo.JPEG')} + >>> data = albu(data) + >>> print(data['img'].shape) + (375, 500, 3) + """ + + def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None): + if albumentations is None: + raise RuntimeError('albumentations is not installed') + else: + from albumentations import Compose as albu_Compose + + assert isinstance(transforms, list), 'transforms must be a list.' + if keymap is not None: + assert isinstance(keymap, dict), 'keymap must be None or a dict. ' + + self.transforms = transforms + + self.aug = albu_Compose( + [self.albu_builder(t) for t in self.transforms]) + + if not keymap: + self.keymap_to_albu = dict(img='image') + else: + self.keymap_to_albu = keymap + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg: Dict): + """Import a module from albumentations. + + It inherits some of :func:`build_from_cfg` logic. + Args: + cfg (dict): Config dict. It should at least contain the key "type". + Returns: + obj: The constructed object. + """ + + assert isinstance(cfg, dict) and 'type' in cfg, 'each item in ' \ + "transforms must be a dict with keyword 'type'." + args = cfg.copy() + + obj_type = args.pop('type') + if mmengine.is_str(obj_type): + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d, keymap): + """Dictionary mapper. + + Renames keys according to keymap provided. + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. + """ + + updated_dict = {} + for k, v in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + def transform(self, results: Dict) -> Dict: + """Transform function to perform albumentations transforms. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Transformed results, 'img' and 'img_shape' keys are + updated in result dict. + """ + assert 'img' in results, 'No `img` field in the input.' + + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + results = self.aug(**results) + + # back to the original format + results = self.mapper(results, self.keymap_back) + results['img_shape'] = results['img'].shape[:2] + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(transforms={repr(self.transforms)})' + return repr_str + + +@TRANSFORMS.register_module() +class SimMIMMaskGenerator(BaseTransform): + """Generate random block mask for each Image. + + **Added Keys**: + + - mask + + This module is used in SimMIM to generate masks. + + Args: + input_size (int): Size of input image. Defaults to 192. + mask_patch_size (int): Size of each block mask. Defaults to 32. + model_patch_size (int): Patch size of each token. Defaults to 4. + mask_ratio (float): The mask ratio of image. Defaults to 0.6. + """ + + def __init__(self, + input_size: int = 192, + mask_patch_size: int = 32, + model_patch_size: int = 4, + mask_ratio: float = 0.6): + self.input_size = input_size + self.mask_patch_size = mask_patch_size + self.model_patch_size = model_patch_size + self.mask_ratio = mask_ratio + + assert self.input_size % self.mask_patch_size == 0 + assert self.mask_patch_size % self.model_patch_size == 0 + + self.rand_size = self.input_size // self.mask_patch_size + self.scale = self.mask_patch_size // self.model_patch_size + + self.token_count = self.rand_size**2 + self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) + + def transform(self, results: dict) -> dict: + """Method to generate random block mask for each Image in SimMIM. + + Args: + results (dict): Result dict from previous pipeline. + + Returns: + dict: Result dict with added key ``mask``. + """ + mask_idx = np.random.permutation(self.token_count)[:self.mask_count] + mask = np.zeros(self.token_count, dtype=int) + mask[mask_idx] = 1 + + mask = mask.reshape((self.rand_size, self.rand_size)) + mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) + + results.update({'mask': mask}) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(input_size={self.input_size}, ' + repr_str += f'mask_patch_size={self.mask_patch_size}, ' + repr_str += f'model_patch_size={self.model_patch_size}, ' + repr_str += f'mask_ratio={self.mask_ratio})' + return repr_str + + +@TRANSFORMS.register_module() +class BEiTMaskGenerator(BaseTransform): + """Generate mask for image. + + **Added Keys**: + + - mask + + This module is borrowed from + https://github.com/microsoft/unilm/tree/master/beit + + Args: + input_size (int): The size of input image. + num_masking_patches (int): The number of patches to be masked. + min_num_patches (int): The minimum number of patches to be masked + in the process of generating mask. Defaults to 4. + max_num_patches (int, optional): The maximum number of patches to be + masked in the process of generating mask. Defaults to None. + min_aspect (float): The minimum aspect ratio of mask blocks. Defaults + to 0.3. + min_aspect (float, optional): The minimum aspect ratio of mask blocks. + Defaults to None. + """ + + def __init__(self, + input_size: int, + num_masking_patches: int, + min_num_patches: int = 4, + max_num_patches: Optional[int] = None, + min_aspect: float = 0.3, + max_aspect: Optional[float] = None) -> None: + if not isinstance(input_size, tuple): + input_size = (input_size, ) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + + self.num_masking_patches = num_masking_patches + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None \ + else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int: + """Generate mask recursively. + + Args: + mask (np.ndarray): The mask to be generated. + max_mask_patches (int): The maximum number of patches to be masked. + + Returns: + int: The number of patches masked. + """ + delta = 0 + for _ in range(10): + target_area = np.random.uniform(self.min_num_patches, + max_mask_patches) + aspect_ratio = math.exp(np.random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = np.random.randint(0, self.height - h) + left = np.random.randint(0, self.width - w) + + num_masked = mask[top:top + h, left:left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + if delta > 0: + break + return delta + + def transform(self, results: dict) -> dict: + """Method to generate random block mask for each Image in BEiT. + + Args: + results (dict): Result dict from previous pipeline. + + Returns: + dict: Result dict with added key ``mask``. + """ + mask = np.zeros(shape=(self.height, self.width), dtype=int) + + mask_count = 0 + while mask_count != self.num_masking_patches: + max_mask_patches = self.num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + mask_count += delta + results.update({'mask': mask}) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(height={self.height}, ' + repr_str += f'width={self.width}, ' + repr_str += f'num_patches={self.num_patches}, ' + repr_str += f'num_masking_patches={self.num_masking_patches}, ' + repr_str += f'min_num_patches={self.min_num_patches}, ' + repr_str += f'max_num_patches={self.max_num_patches}, ' + repr_str += f'log_aspect_ratio={self.log_aspect_ratio})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform): + """Crop the given PIL Image to random size and aspect ratio with random + interpolation. + + **Required Keys**: + + - img + + **Modified Keys**: + + - img + + **Added Keys**: + + - target_img + + This module is borrowed from + https://github.com/microsoft/unilm/tree/master/beit. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a + random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio + is made. This crop is finally resized to given size. This is popularly used + to train the Inception networks. This module first crops the image and + resizes the crop to two different sizes. + + Args: + size (Union[tuple, int]): Expected output size of each edge of the + first image. + second_size (Union[tuple, int], optional): Expected output size of each + edge of the second image. + scale (tuple[float, float]): Range of size of the origin size cropped. + Defaults to (0.08, 1.0). + ratio (tuple[float, float]): Range of aspect ratio of the origin aspect + ratio cropped. Defaults to (3./4., 4./3.). + interpolation (str): The interpolation for the first image. Defaults + to ``bilinear``. + second_interpolation (str): The interpolation for the second image. + Defaults to ``lanczos``. + """ + + def __init__(self, + size: Union[tuple, int], + second_size=None, + scale=(0.08, 1.0), + ratio=(3. / 4., 4. / 3.), + interpolation='bilinear', + second_interpolation='lanczos') -> None: + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if second_size is not None: + if isinstance(second_size, tuple): + self.second_size = second_size + else: + self.second_size = (second_size, second_size) + else: + self.second_size = None + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + ('range should be of kind (min, max)') + + if interpolation == 'random': + self.interpolation = ('bilinear', 'bicubic') + else: + self.interpolation = interpolation + self.second_interpolation = second_interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img: np.ndarray, scale: tuple, + ratio: tuple) -> Sequence[int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (np.ndarray): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect + ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + img_h, img_w = img.shape[:2] + area = img_h * img_w + + for _ in range(10): + target_area = np.random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(np.random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w < img_w and h < img_h: + i = np.random.randint(0, img_h - h) + j = np.random.randint(0, img_w - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img_w / img_h + if in_ratio < min(ratio): + w = img_w + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img_h + w = int(round(h * max(ratio))) + else: # whole image + w = img_w + h = img_h + i = (img_h - h) // 2 + j = (img_w - w) // 2 + return i, j, h, w + + def transform(self, results: dict) -> dict: + """Crop the given image and resize it to two different sizes. + + This module crops the given image randomly and resize the crop to two + different sizes. This is popularly used in BEiT-style masked image + modeling, where an off-the-shelf model is used to provide the target. + + Args: + results (dict): Results from previous pipeline. + + Returns: + dict: Results after applying this transformation. + """ + img = results['img'] + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = np.random.choice(self.interpolation) + else: + interpolation = self.interpolation + if self.second_size is None: + img = img[i:i + h, j:j + w] + img = mmcv.imresize(img, self.size, interpolation=interpolation) + results.update({'img': img}) + else: + img = img[i:i + h, j:j + w] + img_sample = mmcv.imresize( + img, self.size, interpolation=interpolation) + img_target = mmcv.imresize( + img, self.second_size, interpolation=self.second_interpolation) + results.update({'img': [img_sample, img_target]}) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, ' + repr_str += f'second_size={self.second_size}, ' + repr_str += f'interpolation={self.interpolation}, ' + repr_str += f'second_interpolation={self.second_interpolation}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'ratio={self.ratio})' + return repr_str + + +@TRANSFORMS.register_module() +class CleanCaption(BaseTransform): + """Clean caption text. + + Remove some useless punctuation for the caption task. + + **Required Keys:** + + - ``*keys`` + + **Modified Keys:** + + - ``*keys`` + + Args: + keys (Sequence[str], optional): The keys of text to be cleaned. + Defaults to 'gt_caption'. + remove_chars (str): The characters to be removed. Defaults to + :py:attr:`string.punctuation`. + lowercase (bool): Whether to convert the text to lowercase. + Defaults to True. + remove_dup_space (bool): Whether to remove duplicated whitespaces. + Defaults to True. + strip (bool): Whether to remove leading and trailing whitespaces. + Defaults to True. + """ + + def __init__( + self, + keys='gt_caption', + remove_chars=string.punctuation, + lowercase=True, + remove_dup_space=True, + strip=True, + ): + if isinstance(keys, str): + keys = [keys] + self.keys = keys + self.transtab = str.maketrans({ch: None for ch in remove_chars}) + self.lowercase = lowercase + self.remove_dup_space = remove_dup_space + self.strip = strip + + def _clean(self, text): + """Perform text cleaning before tokenizer.""" + + if self.strip: + text = text.strip() + + text = text.translate(self.transtab) + + if self.remove_dup_space: + text = re.sub(r'\s{2,}', ' ', text) + + if self.lowercase: + text = text.lower() + + return text + + def clean(self, text): + """Perform text cleaning before tokenizer.""" + if isinstance(text, (list, tuple)): + return [self._clean(item) for item in text] + elif isinstance(text, str): + return self._clean(text) + else: + raise TypeError('text must be a string or a list of strings') + + def transform(self, results: dict) -> dict: + """Method to clean the input text data.""" + for key in self.keys: + results[key] = self.clean(results[key]) + return results + + +@TRANSFORMS.register_module() +class OFAAddObjects(BaseTransform): + + def transform(self, results: dict) -> dict: + if 'objects' not in results: + raise ValueError( + 'Some OFA fine-tuned models requires `objects` field in the ' + 'dataset, which is generated by VinVL. Or please use ' + 'zero-shot configs. See ' + 'https://github.com/OFA-Sys/OFA/issues/189') + + if 'question' in results: + prompt = '{} object: {}'.format( + results['question'], + ' '.join(results['objects']), + ) + results['decoder_prompt'] = prompt + results['question'] = prompt + + +@TRANSFORMS.register_module() +class RandomTranslatePad(BaseTransform): + + def __init__(self, size=640, aug_translate=False): + self.size = size + self.aug_translate = aug_translate + + @cache_randomness + def rand_translate_params(self, dh, dw): + top = np.random.randint(0, dh) + left = np.random.randint(0, dw) + return top, left + + def transform(self, results: dict) -> dict: + img = results['img'] + h, w = img.shape[:-1] + dw = self.size - w + dh = self.size - h + if self.aug_translate: + top, left = self.rand_translate_params(dh, dw) + else: + top = round(dh / 2.0 - 0.1) + left = round(dw / 2.0 - 0.1) + + out_img = np.zeros((self.size, self.size, 3), dtype=np.float32) + out_img[top:top + h, left:left + w, :] = img + results['img'] = out_img + results['img_shape'] = (self.size, self.size) + + # translate box + if 'gt_bboxes' in results.keys(): + for i in range(len(results['gt_bboxes'])): + box = results['gt_bboxes'][i] + box[0], box[2] = box[0] + left, box[2] + left + box[1], box[3] = box[1] + top, box[3] + top + results['gt_bboxes'][i] = box + + return results diff --git a/mmpretrain/datasets/transforms/utils.py b/mmpretrain/datasets/transforms/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7940486fc9904c14f5a5a4a959022c11456c968 --- /dev/null +++ b/mmpretrain/datasets/transforms/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Union + +from mmcv.transforms import BaseTransform + +PIPELINE_TYPE = List[Union[dict, BaseTransform]] + + +def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int: + """Returns the index of the transform in a pipeline. + + Args: + pipeline (List[dict] | List[BaseTransform]): The transforms list. + target (str): The target transform class name. + + Returns: + int: The transform index. Returns -1 if not found. + """ + for i, transform in enumerate(pipeline): + if isinstance(transform, dict): + if isinstance(transform['type'], type): + if transform['type'].__name__ == target: + return i + else: + if transform['type'] == target: + return i + else: + if transform.__class__.__name__ == target: + return i + + return -1 + + +def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False): + """Remove the target transform type from the pipeline. + + Args: + pipeline (List[dict] | List[BaseTransform]): The transforms list. + target (str): The target transform class name. + inplace (bool): Whether to modify the pipeline inplace. + + Returns: + The modified transform. + """ + idx = get_transform_idx(pipeline, target) + if not inplace: + pipeline = copy.deepcopy(pipeline) + while idx >= 0: + pipeline.pop(idx) + idx = get_transform_idx(pipeline, target) + + return pipeline diff --git a/mmpretrain/datasets/transforms/wrappers.py b/mmpretrain/datasets/transforms/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..c0dfd730b4db0dc80ed315b79658cfbf683e4035 --- /dev/null +++ b/mmpretrain/datasets/transforms/wrappers.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Callable, List, Union + +from mmcv.transforms import BaseTransform, Compose + +from mmpretrain.registry import TRANSFORMS + +# Define type of transform or transform config +Transform = Union[dict, Callable[[dict], dict]] + + +@TRANSFORMS.register_module() +class MultiView(BaseTransform): + """A transform wrapper for multiple views of an image. + + Args: + transforms (list[dict | callable], optional): Sequence of transform + object or config dict to be wrapped. + mapping (dict): A dict that defines the input key mapping. + The keys corresponds to the inner key (i.e., kwargs of the + ``transform`` method), and should be string type. The values + corresponds to the outer keys (i.e., the keys of the + data/results), and should have a type of string, list or dict. + None means not applying input mapping. Default: None. + allow_nonexist_keys (bool): If False, the outer keys in the mapping + must exist in the input data, or an exception will be raised. + Default: False. + + Examples: + >>> # Example 1: MultiViews 1 pipeline with 2 views + >>> pipeline = [ + >>> dict(type='MultiView', + >>> num_views=2, + >>> transforms=[ + >>> [ + >>> dict(type='Resize', scale=224))], + >>> ]) + >>> ] + >>> # Example 2: MultiViews 2 pipelines, the first with 2 views, + >>> # the second with 6 views + >>> pipeline = [ + >>> dict(type='MultiView', + >>> num_views=[2, 6], + >>> transforms=[ + >>> [ + >>> dict(type='Resize', scale=224)], + >>> [ + >>> dict(type='Resize', scale=224), + >>> dict(type='RandomSolarize')], + >>> ]) + >>> ] + """ + + def __init__(self, transforms: List[List[Transform]], + num_views: Union[int, List[int]]) -> None: + + if isinstance(num_views, int): + num_views = [num_views] + assert isinstance(num_views, List) + assert len(num_views) == len(transforms) + self.num_views = num_views + + self.pipelines = [] + for trans in transforms: + pipeline = Compose(trans) + self.pipelines.append(pipeline) + + self.transforms = [] + for i in range(len(num_views)): + self.transforms.extend([self.pipelines[i]] * num_views[i]) + + def transform(self, results: dict) -> dict: + """Apply transformation to inputs. + + Args: + results (dict): Result dict from previous pipelines. + + Returns: + dict: Transformed results. + """ + multi_views_outputs = dict(img=[]) + for trans in self.transforms: + inputs = copy.deepcopy(results) + outputs = trans(inputs) + + multi_views_outputs['img'].append(outputs['img']) + results.update(multi_views_outputs) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + '(' + for i, p in enumerate(self.pipelines): + repr_str += f'\nPipeline {i + 1} with {self.num_views[i]} views:\n' + repr_str += str(p) + repr_str += ')' + return repr_str + + +@TRANSFORMS.register_module() +class ApplyToList(BaseTransform): + """A transform wrapper to apply the wrapped transforms to a list of items. + For example, to load and resize a list of images. + + Args: + transforms (list[dict | callable]): Sequence of transform config dict + to be wrapped. + scatter_key (str): The key to scatter data dict. If the field is a + list, scatter the list to multiple data dicts to do transformation. + collate_keys (List[str]): The keys to collate from multiple data dicts. + The fields in ``collate_keys`` will be composed into a list after + transformation, and the other fields will be adopted from the + first data dict. + """ + + def __init__(self, transforms, scatter_key, collate_keys): + super().__init__() + + self.transforms = Compose([TRANSFORMS.build(t) for t in transforms]) + self.scatter_key = scatter_key + self.collate_keys = set(collate_keys) + self.collate_keys.add(self.scatter_key) + + def transform(self, results: dict): + scatter_field = results.get(self.scatter_key) + + if isinstance(scatter_field, list): + scattered_results = [] + for item in scatter_field: + single_results = copy.deepcopy(results) + single_results[self.scatter_key] = item + scattered_results.append(self.transforms(single_results)) + + final_output = scattered_results[0] + + # merge output list to single output + for key in scattered_results[0].keys(): + if key in self.collate_keys: + final_output[key] = [ + single[key] for single in scattered_results + ] + return final_output + else: + return self.transforms(results) diff --git a/mmpretrain/datasets/utils.py b/mmpretrain/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb60e432c374c1a904700a7348f706fa0e523eb --- /dev/null +++ b/mmpretrain/datasets/utils.py @@ -0,0 +1,243 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import gzip +import hashlib +import os +import os.path +import shutil +import tarfile +import tempfile +import urllib.error +import urllib.request +import zipfile + +from mmengine.fileio import LocalBackend, get_file_backend + +__all__ = [ + 'rm_suffix', 'check_integrity', 'download_and_extract_archive', + 'open_maybe_compressed_file' +] + + +def rm_suffix(s, suffix=None): + if suffix is None: + return s[:s.rfind('.')] + else: + return s[:s.rfind(suffix)] + + +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024): + md5 = hashlib.md5() + backend = get_file_backend(fpath, enable_singleton=True) + if isinstance(backend, LocalBackend): + # Enable chunk update for local file. + with open(fpath, 'rb') as f: + for chunk in iter(lambda: f.read(chunk_size), b''): + md5.update(chunk) + else: + md5.update(backend.get(fpath)) + return md5.hexdigest() + + +def check_md5(fpath, md5, **kwargs): + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath, md5=None): + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def download_url_to_file(url, dst, hash_prefix=None, progress=True): + """Download object at the given URL to a local path. + + Modified from + https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file + + Args: + url (str): URL of the object to download + dst (str): Full path where object will be saved, + e.g. ``/tmp/temporary_file`` + hash_prefix (string, optional): If not None, the SHA256 downloaded + file should start with ``hash_prefix``. Defaults to None. + progress (bool): whether or not to display a progress bar to stderr. + Defaults to True + """ + file_size = None + req = urllib.request.Request(url) + u = urllib.request.urlopen(req) + meta = u.info() + if hasattr(meta, 'getheaders'): + content_length = meta.getheaders('Content-Length') + else: + content_length = meta.get_all('Content-Length') + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + # We deliberately save it in a temp file and move it after download is + # complete. This prevents a local file being overridden by a broken + # download. + dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + + import rich.progress + columns = [ + rich.progress.DownloadColumn(), + rich.progress.BarColumn(bar_width=None), + rich.progress.TimeRemainingColumn(), + ] + try: + if hash_prefix is not None: + sha256 = hashlib.sha256() + with rich.progress.Progress(*columns) as pbar: + task = pbar.add_task('download', total=file_size, visible=progress) + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + if hash_prefix is not None: + sha256.update(buffer) + pbar.update(task, advance=len(buffer)) + + f.close() + if hash_prefix is not None: + digest = sha256.hexdigest() + if digest[:len(hash_prefix)] != hash_prefix: + raise RuntimeError( + 'invalid hash value (expected "{}", got "{}")'.format( + hash_prefix, digest)) + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + +def download_url(url, root, filename=None, md5=None): + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from. + root (str): Directory to place downloaded file in. + filename (str | None): Name to save the file under. + If filename is None, use the basename of the URL. + md5 (str | None): MD5 checksum of the download. + If md5 is None, download without md5 check. + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if check_integrity(fpath, md5): + print(f'Using downloaded and verified file: {fpath}') + else: + try: + print(f'Downloading {url} to {fpath}') + download_url_to_file(url, fpath) + except (urllib.error.URLError, IOError) as e: + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' + f' Downloading {url} to {fpath}') + download_url_to_file(url, fpath) + else: + raise e + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError('File not found or corrupted.') + + +def _is_tarxz(filename): + return filename.endswith('.tar.xz') + + +def _is_tar(filename): + return filename.endswith('.tar') + + +def _is_targz(filename): + return filename.endswith('.tar.gz') + + +def _is_tgz(filename): + return filename.endswith('.tgz') + + +def _is_gzip(filename): + return filename.endswith('.gz') and not filename.endswith('.tar.gz') + + +def _is_zip(filename): + return filename.endswith('.zip') + + +def extract_archive(from_path, to_path=None, remove_finished=False): + if to_path is None: + to_path = os.path.dirname(from_path) + + if _is_tar(from_path): + with tarfile.open(from_path, 'r') as tar: + tar.extractall(path=to_path) + elif _is_targz(from_path) or _is_tgz(from_path): + with tarfile.open(from_path, 'r:gz') as tar: + tar.extractall(path=to_path) + elif _is_tarxz(from_path): + with tarfile.open(from_path, 'r:xz') as tar: + tar.extractall(path=to_path) + elif _is_gzip(from_path): + to_path = os.path.join( + to_path, + os.path.splitext(os.path.basename(from_path))[0]) + with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f: + out_f.write(zip_f.read()) + elif _is_zip(from_path): + with zipfile.ZipFile(from_path, 'r') as z: + z.extractall(to_path) + else: + raise ValueError(f'Extraction of {from_path} not supported') + + if remove_finished: + os.remove(from_path) + + +def download_and_extract_archive(url, + download_root, + extract_root=None, + filename=None, + md5=None, + remove_finished=False): + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print(f'Extracting {archive} to {extract_root}') + extract_archive(archive, extract_root, remove_finished) + + +def open_maybe_compressed_file(path: str): + """Return a file object that possibly decompresses 'path' on the fly. + + Decompression occurs when argument `path` is a string and ends with '.gz' + or '.xz'. + """ + if not isinstance(path, str): + return path + if path.endswith('.gz'): + import gzip + return gzip.open(path, 'rb') + if path.endswith('.xz'): + import lzma + return lzma.open(path, 'rb') + return open(path, 'rb') diff --git a/mmpretrain/datasets/vg_vqa.py b/mmpretrain/datasets/vg_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..2d83884c804086c060bcfe27e833bff28dc28e9e --- /dev/null +++ b/mmpretrain/datasets/vg_vqa.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine.fileio import load + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class VGVQA(BaseDataset): + """Visual Genome VQA dataset.""" + + def load_data_list(self) -> List[dict]: + """Load data list. + + Compare to BaseDataset, the only difference is that coco_vqa annotation + file is already a list of data. There is no 'metainfo'. + """ + + raw_data_list = load(self.ann_file) + if not isinstance(raw_data_list, list): + raise TypeError( + f'The VQA annotations loaded from annotation file ' + f'should be a dict, but got {type(raw_data_list)}!') + + # load and parse data_infos. + data_list = [] + for raw_data_info in raw_data_list: + # parse raw data information to target format + data_info = self.parse_data_info(raw_data_info) + if isinstance(data_info, dict): + # For VQA tasks, each `data_info` looks like: + # { + # "question_id": 986769, + # "question": "How many people are there?", + # "answer": "two", + # "image": "image/1.jpg", + # "dataset": "vg" + # } + + # change 'image' key to 'img_path' + # TODO: This process will be removed, after the annotation file + # is preprocess. + data_info['img_path'] = data_info['image'] + del data_info['image'] + + if 'answer' in data_info: + # add answer_weight & answer_count, delete duplicate answer + if data_info['dataset'] == 'vqa': + answer_weight = {} + for answer in data_info['answer']: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len( + data_info['answer']) + else: + answer_weight[answer] = 1 / len( + data_info['answer']) + + data_info['answer'] = list(answer_weight.keys()) + data_info['answer_weight'] = list( + answer_weight.values()) + data_info['answer_count'] = len(answer_weight) + + elif data_info['dataset'] == 'vg': + data_info['answers'] = [data_info['answer']] + data_info['answer_weight'] = [0.2] + data_info['answer_count'] = 1 + + data_list.append(data_info) + + else: + raise TypeError( + f'Each VQA data element loaded from annotation file ' + f'should be a dict, but got {type(data_info)}!') + + return data_list diff --git a/mmpretrain/datasets/visual_genome.py b/mmpretrain/datasets/visual_genome.py new file mode 100644 index 0000000000000000000000000000000000000000..8c33b86c4f81d0be0f2830618ad100196b461dcf --- /dev/null +++ b/mmpretrain/datasets/visual_genome.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from itertools import chain +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class VisualGenomeQA(BaseDataset): + """Visual Genome Question Answering dataset. + + dataset structure: :: + + data_root + ├── image + │   ├── 1.jpg + │   ├── 2.jpg + │   └── ... + └── question_answers.json + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. Defaults to ``"image"``. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to ``"question_answers.json"``. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str = 'image', + ann_file: str = 'question_answers.json', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def _create_image_index(self): + img_prefix = self.data_prefix['img_path'] + + files = mmengine.list_dir_or_file(img_prefix, list_dir=False) + image_index = {} + for file in files: + image_id = re.findall(r'\d+', file) + if len(image_id) > 0: + image_id = int(image_id[-1]) + image_index[image_id] = mmengine.join_path(img_prefix, file) + + return image_index + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + # The original Visual Genome annotation file and question file includes + # only image id but no image file paths. + self.image_index = self._create_image_index() + + data_list = [] + for qas in chain.from_iterable(ann['qas'] for ann in annotations): + # ann example + # { + # 'id': 1, + # 'qas': [ + # { + # 'a_objects': [], + # 'question': 'What color is the clock?', + # 'image_id': 1, + # 'qa_id': 986768, + # 'answer': 'Two.', + # 'q_objects': [], + # } + # ... + # ] + # } + + data_info = { + 'img_path': self.image_index[qas['image_id']], + 'quesiton': qas['quesiton'], + 'question_id': qas['question_id'], + 'image_id': qas['image_id'], + 'gt_answer': [qas['answer']], + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/vizwiz.py b/mmpretrain/datasets/vizwiz.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5dd394524cac5ad514351ac2a93286c75e1b17 --- /dev/null +++ b/mmpretrain/datasets/vizwiz.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import Counter +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class VizWiz(BaseDataset): + """VizWiz dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + data_list = [] + for ann in annotations: + # { + # "image": "VizWiz_val_00000001.jpg", + # "question": "Can you tell me what this medicine is please?", + # "answers": [ + # { + # "answer": "no", + # "answer_confidence": "yes" + # }, + # { + # "answer": "unanswerable", + # "answer_confidence": "yes" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "unanswerable", + # "answer_confidence": "yes" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time cold medicine", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time medicine", + # "answer_confidence": "yes" + # } + # ], + # "answer_type": "other", + # "answerable": 1 + # }, + data_info = dict() + data_info['question'] = ann['question'] + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], ann['image']) + + if 'answerable' not in ann: + data_list.append(data_info) + else: + if ann['answerable'] == 1: + # add answer_weight & answer_count, delete duplicate answer + answers = [] + for item in ann.pop('answers'): + if item['answer_confidence'] == 'yes' and item[ + 'answer'] != 'unanswerable': + answers.append(item['answer']) + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + data_info['gt_answer'] = list(count.keys()) + data_info['gt_answer_weight'] = answer_weight + # data_info.update(ann) + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/voc.py b/mmpretrain/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..39544de7a1794a2d965189c692f652cc56b218f9 --- /dev/null +++ b/mmpretrain/datasets/voc.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import xml.etree.ElementTree as ET +from typing import List, Optional, Union + +from mmengine import get_file_backend, list_from_file +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import expanduser +from .categories import VOC2007_CATEGORIES +from .multi_label import MultiLabelDataset + + +@DATASETS.register_module() +class VOC(MultiLabelDataset): + """`Pascal VOC `_ Dataset. + + After decompression, the dataset directory structure is as follows: + + VOC dataset directory: :: + + VOC2007 + ├── JPEGImages + │ ├── xxx.jpg + │ ├── xxy.jpg + │ └── ... + ├── Annotations + │ ├── xxx.xml + │ ├── xxy.xml + │ └── ... + └── ImageSets + └── Main + ├── train.txt + ├── val.txt + ├── trainval.txt + ├── test.txt + └── ... + + Extra difficult label is in VOC annotations, we will use + `gt_label_difficult` to record the difficult labels in each sample + and corresponding evaluation should take care of this field + to calculate metrics. Usually, difficult labels are reckoned as + negative in defaults. + + Args: + data_root (str): The root directory for VOC dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". + image_set_path (str, optional): The path of image set, The file which + lists image ids of the sub dataset, and this path is relative + to ``data_root``. Default to ''. + data_prefix (dict): Prefix for data and annotation, keyword + 'img_path' and 'ann_path' can be set. Defaults to be + ``dict(img_path='JPEGImages', ann_path='Annotations')``. + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import VOC + >>> train_dataset = VOC(data_root='data/VOC2007', split='trainval') + >>> train_dataset + Dataset VOC + Number of samples: 5011 + Number of categories: 20 + Prefix of dataset: data/VOC2007 + Path of image set: data/VOC2007/ImageSets/Main/trainval.txt + Prefix of images: data/VOC2007/JPEGImages + Prefix of annotations: data/VOC2007/Annotations + >>> test_dataset = VOC(data_root='data/VOC2007', split='test') + >>> test_dataset + Dataset VOC + Number of samples: 4952 + Number of categories: 20 + Prefix of dataset: data/VOC2007 + Path of image set: data/VOC2007/ImageSets/Main/test.txt + Prefix of images: data/VOC2007/JPEGImages + Prefix of annotations: data/VOC2007/Annotations + """ # noqa: E501 + + METAINFO = {'classes': VOC2007_CATEGORIES} + + def __init__(self, + data_root: str, + split: str = 'trainval', + image_set_path: str = '', + data_prefix: Union[str, dict] = dict( + img_path='JPEGImages', ann_path='Annotations'), + test_mode: bool = False, + metainfo: Optional[dict] = None, + **kwargs): + + self.backend = get_file_backend(data_root, enable_singleton=True) + + if split: + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + if not data_prefix: + data_prefix = dict( + img_path='JPEGImages', ann_path='Annotations') + if not image_set_path: + image_set_path = self.backend.join_path( + 'ImageSets', 'Main', f'{split}.txt') + + # To handle the BC-breaking + if (split == 'train' or split == 'trainval') and test_mode: + logger = MMLogger.get_current_instance() + logger.warning(f'split="{split}" but test_mode=True. ' + f'The {split} set will be used.') + + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + assert isinstance(data_prefix, dict) and 'img_path' in data_prefix, \ + '`data_prefix` must be a dict with key img_path' + + if (split and split not in ['val', 'test']) or not test_mode: + assert 'ann_path' in data_prefix and data_prefix[ + 'ann_path'] is not None, \ + '"ann_path" must be set in `data_prefix`' \ + 'when validation or test set is used.' + + self.data_root = data_root + self.image_set_path = self.backend.join_path(data_root, image_set_path) + + super().__init__( + ann_file='', + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + @property + def ann_prefix(self): + """The prefix of images.""" + if 'ann_path' in self.data_prefix: + return self.data_prefix['ann_path'] + else: + return None + + def _get_labels_from_xml(self, img_id): + """Get gt_labels and labels_difficult from xml file.""" + xml_path = self.backend.join_path(self.ann_prefix, f'{img_id}.xml') + content = self.backend.get(xml_path) + root = ET.fromstring(content) + + labels, labels_difficult = set(), set() + for obj in root.findall('object'): + label_name = obj.find('name').text + # in case customized dataset has wrong labels + # or CLASSES has been override. + if label_name not in self.CLASSES: + continue + label = self.class_to_idx[label_name] + difficult = int(obj.find('difficult').text) + if difficult: + labels_difficult.add(label) + else: + labels.add(label) + + return list(labels), list(labels_difficult) + + def load_data_list(self): + """Load images and ground truth labels.""" + data_list = [] + img_ids = list_from_file(self.image_set_path) + + for img_id in img_ids: + img_path = self.backend.join_path(self.img_prefix, f'{img_id}.jpg') + + labels, labels_difficult = None, None + if self.ann_prefix is not None: + labels, labels_difficult = self._get_labels_from_xml(img_id) + + info = dict( + img_path=img_path, + gt_label=labels, + gt_label_difficult=labels_difficult) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Prefix of dataset: \t{self.data_root}', + f'Path of image set: \t{self.image_set_path}', + f'Prefix of images: \t{self.img_prefix}', + f'Prefix of annotations: \t{self.ann_prefix}' + ] + + return body diff --git a/mmpretrain/datasets/vsr.py b/mmpretrain/datasets/vsr.py new file mode 100644 index 0000000000000000000000000000000000000000..7b109592bd020d57e3db8f2ff610901e2a1d9f31 --- /dev/null +++ b/mmpretrain/datasets/vsr.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class VSR(BaseDataset): + """VSR: Visual Spatial Reasoning dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + data_list = [] + for ann in annotations: + # ann example + # { + # "image": "train2017/000000372029.jpg", + # "question": "The dog is on the surfboard.", + # "answer": true + # } + data_info = dict() + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], ann['image']) + data_info['question'] = ann['question'] + data_info['gt_answer'] = 'yes' if ann['answer'] else 'no' + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/engine/__init__.py b/mmpretrain/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..332fea0909b4abdc6a83cf7662ea916a777d99dd --- /dev/null +++ b/mmpretrain/engine/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hooks import * # noqa: F401, F403 +from .optimizers import * # noqa: F401, F403 +from .runners import * # noqa: F401, F403 +from .schedulers import * # noqa: F401, F403 diff --git a/mmpretrain/engine/hooks/__init__.py b/mmpretrain/engine/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9e22be7e96d636f202066f2e00e7699b730619 --- /dev/null +++ b/mmpretrain/engine/hooks/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .class_num_check_hook import ClassNumCheckHook +from .densecl_hook import DenseCLHook +from .ema_hook import EMAHook +from .margin_head_hooks import SetAdaptiveMarginsHook +from .precise_bn_hook import PreciseBNHook +from .retriever_hooks import PrepareProtoBeforeValLoopHook +from .simsiam_hook import SimSiamHook +from .swav_hook import SwAVHook +from .switch_recipe_hook import SwitchRecipeHook +from .visualization_hook import VisualizationHook +from .warmup_param_hook import WarmupParamHook + +__all__ = [ + 'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook', + 'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook', + 'SetAdaptiveMarginsHook', 'EMAHook', 'SimSiamHook', 'DenseCLHook', + 'SwAVHook', 'WarmupParamHook' +] diff --git a/mmpretrain/engine/hooks/class_num_check_hook.py b/mmpretrain/engine/hooks/class_num_check_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..38170d6604810c575aa5c2c9435c0b75cfa761b2 --- /dev/null +++ b/mmpretrain/engine/hooks/class_num_check_hook.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved +from mmengine.hooks import Hook +from mmengine.utils import is_seq_of + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class ClassNumCheckHook(Hook): + """Class Number Check HOOK.""" + + def _check_head(self, runner, dataset): + """Check whether the `num_classes` in head matches the length of + `CLASSES` in `dataset`. + + Args: + runner (obj:`Runner`): runner object. + dataset (obj: `BaseDataset`): the dataset to check. + """ + model = runner.model + if dataset.CLASSES is None: + runner.logger.warning( + f'Please set class information in `metainfo` ' + f'in the {dataset.__class__.__name__} and' + f'check if it is consistent with the `num_classes` ' + f'of head') + else: + assert is_seq_of(dataset.CLASSES, str), \ + (f'Class information in `metainfo` in ' + f'{dataset.__class__.__name__} should be a tuple of str.') + for _, module in model.named_modules(): + if hasattr(module, 'num_classes'): + assert module.num_classes == len(dataset.CLASSES), \ + (f'The `num_classes` ({module.num_classes}) in ' + f'{module.__class__.__name__} of ' + f'{model.__class__.__name__} does not matches ' + f'the length of class information in `metainfo` ' + f'{len(dataset.CLASSES)}) in ' + f'{dataset.__class__.__name__}') + + def before_train(self, runner): + """Check whether the training dataset is compatible with head. + + Args: + runner (obj: `IterBasedRunner`): Iter based Runner. + """ + self._check_head(runner, runner.train_dataloader.dataset) + + def before_val(self, runner): + """Check whether the validation dataset is compatible with head. + + Args: + runner (obj:`IterBasedRunner`): Iter based Runner. + """ + self._check_head(runner, runner.val_dataloader.dataset) + + def before_test(self, runner): + """Check whether the test dataset is compatible with head. + + Args: + runner (obj:`IterBasedRunner`): Iter based Runner. + """ + self._check_head(runner, runner.test_dataloader.dataset) diff --git a/mmpretrain/engine/hooks/densecl_hook.py b/mmpretrain/engine/hooks/densecl_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7e17d3419cbc2a540d3aecd81e223eed670df2 --- /dev/null +++ b/mmpretrain/engine/hooks/densecl_hook.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class DenseCLHook(Hook): + """Hook for DenseCL. + + This hook includes ``loss_lambda`` warmup in DenseCL. + Borrowed from the authors' code: ``_. + + Args: + start_iters (int): The number of warmup iterations to set + ``loss_lambda=0``. Defaults to 1000. + """ + + def __init__(self, start_iters: int = 1000) -> None: + self.start_iters = start_iters + + def before_train(self, runner) -> None: + """Obtain ``loss_lambda`` from algorithm.""" + assert hasattr(get_ori_model(runner.model), 'loss_lambda'), \ + "The runner must have attribute \"loss_lambda\" in DenseCL." + self.loss_lambda = get_ori_model(runner.model).loss_lambda + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """Adjust ``loss_lambda`` every train iter.""" + assert hasattr(get_ori_model(runner.model), 'loss_lambda'), \ + "The runner must have attribute \"loss_lambda\" in DenseCL." + cur_iter = runner.iter + if cur_iter >= self.start_iters: + get_ori_model(runner.model).loss_lambda = self.loss_lambda + else: + get_ori_model(runner.model).loss_lambda = 0. diff --git a/mmpretrain/engine/hooks/ema_hook.py b/mmpretrain/engine/hooks/ema_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..284d211b628c411f0eb712d1c558dc6aa2eb8996 --- /dev/null +++ b/mmpretrain/engine/hooks/ema_hook.py @@ -0,0 +1,216 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import itertools +import warnings +from typing import Dict, Optional + +from mmengine.hooks import EMAHook as BaseEMAHook +from mmengine.logging import MMLogger +from mmengine.runner import Runner + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class EMAHook(BaseEMAHook): + """A Hook to apply Exponential Moving Average (EMA) on the model during + training. + + Comparing with :class:`mmengine.hooks.EMAHook`, this hook accepts + ``evaluate_on_ema`` and ``evaluate_on_origin`` arguments. By default, the + ``evaluate_on_ema`` is enabled, and if you want to do validation and + testing on both original and EMA models, please set both arguments + ``True``. + + Note: + - EMAHook takes priority over CheckpointHook. + - The original model parameters are actually saved in ema field after + train. + - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time. + + Args: + ema_type (str): The type of EMA strategy to use. You can find the + supported strategies in :mod:`mmengine.model.averaged_model`. + Defaults to 'ExponentialMovingAverage'. + strict_load (bool): Whether to strictly enforce that the keys of + ``state_dict`` in checkpoint match the keys returned by + ``self.module.state_dict``. Defaults to False. + Changed in v0.3.0. + begin_iter (int): The number of iteration to enable ``EMAHook``. + Defaults to 0. + begin_epoch (int): The number of epoch to enable ``EMAHook``. + Defaults to 0. + evaluate_on_ema (bool): Whether to evaluate (validate and test) + on EMA model during val-loop and test-loop. Defaults to True. + evaluate_on_origin (bool): Whether to evaluate (validate and test) + on the original model during val-loop and test-loop. + Defaults to False. + **kwargs: Keyword arguments passed to subclasses of + :obj:`BaseAveragedModel` + """ + + priority = 'NORMAL' + + def __init__(self, + ema_type: str = 'ExponentialMovingAverage', + strict_load: bool = False, + begin_iter: int = 0, + begin_epoch: int = 0, + evaluate_on_ema: bool = True, + evaluate_on_origin: bool = False, + **kwargs): + super().__init__( + ema_type=ema_type, + strict_load=strict_load, + begin_iter=begin_iter, + begin_epoch=begin_epoch, + **kwargs) + + if not evaluate_on_ema and not evaluate_on_origin: + warnings.warn( + 'Automatically set `evaluate_on_origin=True` since the ' + '`evaluate_on_ema` is disabled. If you want to disable ' + 'all validation, please modify the `val_interval` of ' + 'the `train_cfg`.', UserWarning) + evaluate_on_origin = True + + self.evaluate_on_ema = evaluate_on_ema + self.evaluate_on_origin = evaluate_on_origin + self.load_ema_from_ckpt = False + + def before_train(self, runner) -> None: + super().before_train(runner) + if not runner._resume and self.load_ema_from_ckpt: + # If loaded EMA state dict but not want to resume training + # overwrite the EMA state dict with the source model. + MMLogger.get_current_instance().info( + 'Load from a checkpoint with EMA parameters but not ' + 'resume training. Initialize the model parameters with ' + 'EMA parameters') + for p_ema, p_src in zip(self._ema_params, self._src_params): + p_src.data.copy_(p_ema.data) + + def before_val_epoch(self, runner) -> None: + """We load parameter values from ema model to source model before + validation. + + Args: + runner (Runner): The runner of the training process. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + def after_val_epoch(self, + runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """We recover source model's parameter from ema model after validation. + + Args: + runner (Runner): The runner of the validation process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + if self.evaluate_on_ema and self.evaluate_on_origin: + # Re-evaluate if evaluate on both ema and origin. + val_loop = runner.val_loop + + runner.model.eval() + for idx, data_batch in enumerate(val_loop.dataloader): + val_loop.run_iter(idx, data_batch) + + # compute metrics + origin_metrics = val_loop.evaluator.evaluate( + len(val_loop.dataloader.dataset)) + + for k, v in origin_metrics.items(): + runner.message_hub.update_scalar(f'val/{k}_origin', v) + + def before_test_epoch(self, runner) -> None: + """We load parameter values from ema model to source model before test. + + Args: + runner (Runner): The runner of the training process. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + MMLogger.get_current_instance().info('Start testing on EMA model.') + else: + MMLogger.get_current_instance().info( + 'Start testing on the original model.') + + def after_test_epoch(self, + runner: Runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """We recover source model's parameter from ema model after test. + + Args: + runner (Runner): The runner of the testing process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on test dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + if self.evaluate_on_ema and self.evaluate_on_origin: + # Re-evaluate if evaluate on both ema and origin. + MMLogger.get_current_instance().info( + 'Start testing on the original model.') + test_loop = runner.test_loop + + runner.model.eval() + for idx, data_batch in enumerate(test_loop.dataloader): + test_loop.run_iter(idx, data_batch) + + # compute metrics + origin_metrics = test_loop.evaluator.evaluate( + len(test_loop.dataloader.dataset)) + + for k, v in origin_metrics.items(): + runner.message_hub.update_scalar(f'test/{k}_origin', v) + + def after_load_checkpoint(self, runner, checkpoint: dict) -> None: + """Resume ema parameters from checkpoint. + + Args: + runner (Runner): The runner of the testing process. + """ + from mmengine.runner.checkpoint import load_state_dict + if 'ema_state_dict' in checkpoint: + # The original model parameters are actually saved in ema + # field swap the weights back to resume ema state. + self._swap_ema_state_dict(checkpoint) + self.ema_model.load_state_dict( + checkpoint['ema_state_dict'], strict=self.strict_load) + self.load_ema_from_ckpt = True + + # Support load checkpoint without ema state dict. + else: + load_state_dict( + self.ema_model.module, + copy.deepcopy(checkpoint['state_dict']), + strict=self.strict_load) + + @property + def _src_params(self): + if self.ema_model.update_buffers: + return itertools.chain(self.src_model.parameters(), + self.src_model.buffers()) + else: + return self.src_model.parameters() + + @property + def _ema_params(self): + if self.ema_model.update_buffers: + return itertools.chain(self.ema_model.module.parameters(), + self.ema_model.module.buffers()) + else: + return self.ema_model.module.parameters() diff --git a/mmpretrain/engine/hooks/margin_head_hooks.py b/mmpretrain/engine/hooks/margin_head_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..fbeae7a347453153ff4ab3bef958acb549623f6f --- /dev/null +++ b/mmpretrain/engine/hooks/margin_head_hooks.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved +import numpy as np +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models.heads import ArcFaceClsHead +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class SetAdaptiveMarginsHook(Hook): + r"""Set adaptive-margins in ArcFaceClsHead based on the power of + category-wise count. + + A PyTorch implementation of paper `Google Landmark Recognition 2020 + Competition Third Place Solution `_. + The margins will be + :math:`\text{f}(n) = (marginMax - marginMin) · norm(n^p) + marginMin`. + The `n` indicates the number of occurrences of a category. + + Args: + margin_min (float): Lower bound of margins. Defaults to 0.05. + margin_max (float): Upper bound of margins. Defaults to 0.5. + power (float): The power of category freqercy. Defaults to -0.25. + """ + + def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None: + self.margin_min = margin_min + self.margin_max = margin_max + self.margin_range = margin_max - margin_min + self.p = power + + def before_train(self, runner): + """change the margins in ArcFaceClsHead. + + Args: + runner (obj: `Runner`): Runner. + """ + model = runner.model + if is_model_wrapper(model): + model = model.module + + if (hasattr(model, 'head') + and not isinstance(model.head, ArcFaceClsHead)): + raise ValueError( + 'Hook ``SetFreqPowAdvMarginsHook`` could only be used ' + f'for ``ArcFaceClsHead``, but get {type(model.head)}') + + # generate margins base on the dataset. + gt_labels = runner.train_dataloader.dataset.get_gt_labels() + label_count = np.bincount(gt_labels) + label_count[label_count == 0] = 1 # At least one occurrence + pow_freq = np.power(label_count, self.p) + + min_f, max_f = pow_freq.min(), pow_freq.max() + normized_pow_freq = (pow_freq - min_f) / (max_f - min_f) + margins = normized_pow_freq * self.margin_range + self.margin_min + + assert len(margins) == runner.model.head.num_classes + + model.head.set_margins(margins) diff --git a/mmpretrain/engine/hooks/precise_bn_hook.py b/mmpretrain/engine/hooks/precise_bn_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb0e4c419e4ed2af23574769815aaecbcd629c0 --- /dev/null +++ b/mmpretrain/engine/hooks/precise_bn_hook.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py # noqa: E501 +# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0 # noqa: E501 + +import itertools +import logging +from typing import List, Optional, Sequence, Union + +import mmengine +import torch +import torch.nn as nn +from mmengine.hooks import Hook +from mmengine.logging import print_log +from mmengine.model import is_model_wrapper +from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop, Runner +from mmengine.utils import ProgressBar +from torch.functional import Tensor +from torch.nn import GroupNorm +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.instancenorm import _InstanceNorm +from torch.utils.data import DataLoader + +from mmpretrain.registry import HOOKS + +DATA_BATCH = Optional[Sequence[dict]] + + +def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]: + """Performs the scaled all_reduce operation on the provided tensors. + + The input tensors are modified in-place. Currently supports only the sum + reduction operator. The reduced values are scaled by the inverse size of + the process group. + + Args: + tensors (List[torch.Tensor]): The tensors to process. + num_gpus (int): The number of gpus to use + Returns: + List[torch.Tensor]: The processed tensors. + """ + # There is no need for reduction in the single-proc case + if num_gpus == 1: + return tensors + # Queue the reductions + reductions = [] + for tensor in tensors: + reduction = torch.distributed.all_reduce(tensor, async_op=True) + reductions.append(reduction) + # Wait for reductions to finish + for reduction in reductions: + reduction.wait() + # Scale the results + for tensor in tensors: + tensor.mul_(1.0 / num_gpus) + return tensors + + +@torch.no_grad() +def update_bn_stats( + model: nn.Module, + loader: DataLoader, + num_samples: int = 8192, + logger: Optional[Union[logging.Logger, str]] = None) -> None: + """Computes precise BN stats on training data. + + Args: + model (nn.module): The model whose bn stats will be recomputed. + loader (DataLoader): PyTorch dataloader._dataloader + num_samples (int): The number of samples to update the bn stats. + Defaults to 8192. + logger (logging.Logger or str, optional): If the type of logger is + ``logging.Logger``, we directly use logger to log messages. + Some special loggers are: + - "silent": No message will be printed. + - "current": Use latest created logger to log message. + - other str: Instance name of logger. The corresponding logger + will log message if it has been created, otherwise will raise a + `ValueError`. + - None: The `print()` method will be used to print log messages. + """ + if is_model_wrapper(model): + model = model.module + + # get dist info + rank, world_size = mmengine.dist.get_dist_info() + # Compute the number of mini-batches to use, if the size of dataloader is + # less than num_iters, use all the samples in dataloader. + num_iter = num_samples // (loader.batch_size * world_size) + num_iter = min(num_iter, len(loader)) + # Retrieve the BN layers + bn_layers = [ + m for m in model.modules() + if m.training and isinstance(m, (_BatchNorm)) + ] + if len(bn_layers) == 0: + print_log('No BN found in model', logger=logger, level=logging.WARNING) + return + print_log( + f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger) + + # Finds all the other norm layers with training=True. + other_norm_layers = [ + m for m in model.modules() + if m.training and isinstance(m, (_InstanceNorm, GroupNorm)) + ] + if len(other_norm_layers) > 0: + print_log( + 'IN/GN stats will not be updated in PreciseHook.', + logger=logger, + level=logging.INFO) + + # Initialize BN stats storage for computing + # mean(mean(batch)) and mean(var(batch)) + running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers] + running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers] + # Remember momentum values + momentums = [bn.momentum for bn in bn_layers] + # Set momentum to 1.0 to compute BN stats that reflect the current batch + for bn in bn_layers: + bn.momentum = 1.0 + # Average the BN stats for each BN layer over the batches + if rank == 0: + prog_bar = ProgressBar(num_iter) + + for data in itertools.islice(loader, num_iter): + data = model.data_preprocessor(data, False) + model(**data) + + for i, bn in enumerate(bn_layers): + running_means[i] += bn.running_mean / num_iter + running_vars[i] += bn.running_var / num_iter + if rank == 0: + prog_bar.update() + + # Sync BN stats across GPUs (no reduction if 1 GPU used) + running_means = scaled_all_reduce(running_means, world_size) + running_vars = scaled_all_reduce(running_vars, world_size) + # Set BN stats and restore original momentum values + for i, bn in enumerate(bn_layers): + bn.running_mean = running_means[i] + bn.running_var = running_vars[i] + bn.momentum = momentums[i] + + +@HOOKS.register_module() +class PreciseBNHook(Hook): + """Precise BN hook. + + Recompute and update the batch norm stats to make them more precise. During + training both BN stats and the weight are changing after every iteration, + so the running average can not precisely reflect the actual stats of the + current model. + + With this hook, the BN stats are recomputed with fixed weights, to make the + running average more precise. Specifically, it computes the true average of + per-batch mean/variance instead of the running average. See Sec. 3 of the + paper `Rethinking Batch in BatchNorm ` + for details. + + This hook will update BN stats, so it should be executed before + ``CheckpointHook`` and ``EMAHook``, generally set its priority to + "ABOVE_NORMAL". + + Args: + num_samples (int): The number of samples to update the bn stats. + Defaults to 8192. + interval (int): Perform precise bn interval. If the train loop is + `EpochBasedTrainLoop` or `by_epoch=True`, its unit is 'epoch'; if the + train loop is `IterBasedTrainLoop` or `by_epoch=False`, its unit is + 'iter'. Defaults to 1. + """ + + def __init__(self, num_samples: int = 8192, interval: int = 1) -> None: + assert interval > 0 and num_samples > 0, "'interval' and " \ + "'num_samples' must be bigger than 0." + + self.interval = interval + self.num_samples = num_samples + + def _perform_precise_bn(self, runner: Runner) -> None: + """perform precise bn.""" + print_log( + f'Running Precise BN for {self.num_samples} samples...', + logger=runner.logger) + update_bn_stats( + runner.model, + runner.train_loop.dataloader, + self.num_samples, + logger=runner.logger) + print_log('Finish Precise BN, BN stats updated.', logger=runner.logger) + + def after_train_epoch(self, runner: Runner) -> None: + """Calculate prcise BN and broadcast BN stats across GPUs. + + Args: + runner (obj:`Runner`): The runner of the training process. + """ + # if use `EpochBasedTrainLoop``, do perform precise every + # `self.interval` epochs. + if isinstance(runner.train_loop, + EpochBasedTrainLoop) and self.every_n_epochs( + runner, self.interval): + self._perform_precise_bn(runner) + + def after_train_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: + """Calculate prcise BN and broadcast BN stats across GPUs. + + Args: + runner (obj:`Runner`): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. + """ + # if use `IterBasedTrainLoop``, do perform precise every + # `self.interval` iters. + if isinstance(runner.train_loop, + IterBasedTrainLoop) and self.every_n_train_iters( + runner, self.interval): + self._perform_precise_bn(runner) diff --git a/mmpretrain/engine/hooks/retriever_hooks.py b/mmpretrain/engine/hooks/retriever_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd7c7aaff3175491b1ea1508e33b07b7c2ea8d4 --- /dev/null +++ b/mmpretrain/engine/hooks/retriever_hooks.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved +import warnings + +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models import BaseRetriever +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class PrepareProtoBeforeValLoopHook(Hook): + """The hook to prepare the prototype in retrievers. + + Since the encoders of the retriever changes during training, the prototype + changes accordingly. So the `prototype_vecs` needs to be regenerated before + validation loop. + """ + + def before_val(self, runner) -> None: + model = runner.model + if is_model_wrapper(model): + model = model.module + + if isinstance(model, BaseRetriever): + if hasattr(model, 'prepare_prototype'): + model.prepare_prototype() + else: + warnings.warn( + 'Only the `mmpretrain.models.retrievers.BaseRetriever` ' + 'can execute `PrepareRetrieverPrototypeHook`, but got ' + f'`{type(model)}`') diff --git a/mmpretrain/engine/hooks/simsiam_hook.py b/mmpretrain/engine/hooks/simsiam_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..fabc4faca02bb78b92c39de68fa8a18e56d544f5 --- /dev/null +++ b/mmpretrain/engine/hooks/simsiam_hook.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class SimSiamHook(Hook): + """Hook for SimSiam. + + This hook is for SimSiam to fix learning rate of predictor. + + Args: + fix_pred_lr (bool): whether to fix the lr of predictor or not. + lr (float): the value of fixed lr. + adjust_by_epoch (bool, optional): whether to set lr by epoch or iter. + Defaults to True. + """ + + def __init__(self, + fix_pred_lr: bool, + lr: float, + adjust_by_epoch: Optional[bool] = True) -> None: + self.fix_pred_lr = fix_pred_lr + self.lr = lr + self.adjust_by_epoch = adjust_by_epoch + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """fix lr of predictor by iter.""" + if self.adjust_by_epoch: + return + else: + if self.fix_pred_lr: + for param_group in runner.optim_wrapper.optimizer.param_groups: + if 'fix_lr' in param_group and param_group['fix_lr']: + param_group['lr'] = self.lr + + def before_train_epoch(self, runner) -> None: + """fix lr of predictor by epoch.""" + if self.fix_pred_lr: + for param_group in runner.optim_wrapper.optimizer.param_groups: + if 'fix_lr' in param_group and param_group['fix_lr']: + param_group['lr'] = self.lr diff --git a/mmpretrain/engine/hooks/swav_hook.py b/mmpretrain/engine/hooks/swav_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..71c82ad1e0c47114cccdd90f26a3d6c086e36d18 --- /dev/null +++ b/mmpretrain/engine/hooks/swav_hook.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Dict, List, Optional, Sequence + +import torch +from mmengine.dist import get_rank, get_world_size, is_distributed +from mmengine.hooks import Hook +from mmengine.logging import MMLogger + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class SwAVHook(Hook): + """Hook for SwAV. + + This hook builds the queue in SwAV according to ``epoch_queue_starts``. + The queue will be saved in ``runner.work_dir`` or loaded at start epoch + if the path folder has queues saved before. + + Args: + batch_size (int): the batch size per GPU for computing. + epoch_queue_starts (int, optional): from this epoch, starts to use the + queue. Defaults to 15. + crops_for_assign (list[int], optional): list of crops id used for + computing assignments. Defaults to [0, 1]. + feat_dim (int, optional): feature dimension of output vector. + Defaults to 128. + queue_length (int, optional): length of the queue (0 for no queue). + Defaults to 0. + interval (int, optional): the interval to save the queue. + Defaults to 1. + frozen_layers_cfg (dict, optional): Dict to config frozen layers. + The key-value pair is layer name and its frozen iters. If frozen, + the layers don't need gradient. Defaults to dict(). + """ + + def __init__( + self, + batch_size: int, + epoch_queue_starts: Optional[int] = 15, + crops_for_assign: Optional[List[int]] = [0, 1], + feat_dim: Optional[int] = 128, + queue_length: Optional[int] = 0, + interval: Optional[int] = 1, + frozen_layers_cfg: Optional[Dict] = dict() + ) -> None: + self.batch_size = batch_size * get_world_size() + self.epoch_queue_starts = epoch_queue_starts + self.crops_for_assign = crops_for_assign + self.feat_dim = feat_dim + self.queue_length = queue_length + self.interval = interval + self.frozen_layers_cfg = frozen_layers_cfg + self.requires_grad = True + self.queue = None + + def before_run(self, runner) -> None: + """Check whether the queues exist locally or not.""" + if is_distributed(): + self.queue_path = osp.join(runner.work_dir, + 'queue' + str(get_rank()) + '.pth') + else: + self.queue_path = osp.join(runner.work_dir, 'queue.pth') + + # load the queues if queues exist locally + if osp.isfile(self.queue_path): + self.queue = torch.load(self.queue_path)['queue'] + get_ori_model(runner.model).head.loss_module.queue = self.queue + MMLogger.get_current_instance().info( + f'Load queue from file: {self.queue_path}') + + # the queue needs to be divisible by the batch size + self.queue_length -= self.queue_length % self.batch_size + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """Freeze layers before specific iters according to the config.""" + for layer, frozen_iters in self.frozen_layers_cfg.items(): + if runner.iter < frozen_iters and self.requires_grad: + self.requires_grad = False + for name, p in get_ori_model(runner.model).named_parameters(): + if layer in name: + p.requires_grad = False + elif runner.iter >= frozen_iters and not self.requires_grad: + self.requires_grad = True + for name, p in get_ori_model(runner.model).named_parameters(): + if layer in name: + p.requires_grad = True + + def before_train_epoch(self, runner) -> None: + """Check the queues' state.""" + # optionally starts a queue + if self.queue_length > 0 \ + and runner.epoch >= self.epoch_queue_starts \ + and self.queue is None: + self.queue = torch.zeros( + len(self.crops_for_assign), + self.queue_length // runner.world_size, + self.feat_dim, + ).cuda() + + # set the boolean type of use_the_queue + get_ori_model(runner.model).head.loss_module.queue = self.queue + get_ori_model(runner.model).head.loss_module.use_queue = False + + def after_train_epoch(self, runner) -> None: + """Save the queues locally.""" + self.queue = get_ori_model(runner.model).head.loss_module.queue + + if self.queue is not None and self.every_n_epochs( + runner, self.interval): + torch.save({'queue': self.queue}, self.queue_path) diff --git a/mmpretrain/engine/hooks/switch_recipe_hook.py b/mmpretrain/engine/hooks/switch_recipe_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..914b9572eb22d2cd2f54c519273c86baf2e0894d --- /dev/null +++ b/mmpretrain/engine/hooks/switch_recipe_hook.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from copy import deepcopy + +from mmcv.transforms import Compose +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models.utils import RandomBatchAugment +from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS + + +@HOOKS.register_module() +class SwitchRecipeHook(Hook): + """switch recipe during the training loop, including train pipeline, batch + augments and loss currently. + + Args: + schedule (list): Every item of the schedule list should be a dict, and + the dict should have ``action_epoch`` and some of + ``train_pipeline``, ``train_augments`` and ``loss`` keys: + + - ``action_epoch`` (int): switch training recipe at which epoch. + - ``train_pipeline`` (list, optional): The new data pipeline of the + train dataset. If not specified, keep the original settings. + - ``batch_augments`` (dict | None, optional): The new batch + augmentations of during training. See :mod:`Batch Augmentations + ` for more details. + If None, disable batch augmentations. If not specified, keep the + original settings. + - ``loss`` (dict, optional): The new loss module config. If not + specified, keep the original settings. + + Example: + To use this hook in config files. + + .. code:: python + + custom_hooks = [ + dict( + type='SwitchRecipeHook', + schedule=[ + dict( + action_epoch=30, + train_pipeline=pipeline_after_30e, + batch_augments=batch_augments_after_30e, + loss=loss_after_30e, + ), + dict( + action_epoch=60, + # Disable batch augmentations after 60e + # and keep other settings. + batch_augments=None, + ), + ] + ) + ] + """ + priority = 'NORMAL' + + def __init__(self, schedule): + recipes = {} + for recipe in schedule: + assert 'action_epoch' in recipe, \ + 'Please set `action_epoch` in every item ' \ + 'of the `schedule` in the SwitchRecipeHook.' + recipe = deepcopy(recipe) + if 'train_pipeline' in recipe: + recipe['train_pipeline'] = Compose(recipe['train_pipeline']) + if 'batch_augments' in recipe: + batch_augments = recipe['batch_augments'] + if isinstance(batch_augments, dict): + batch_augments = RandomBatchAugment(**batch_augments) + recipe['batch_augments'] = batch_augments + if 'loss' in recipe: + loss = recipe['loss'] + if isinstance(loss, dict): + loss = MODELS.build(loss) + recipe['loss'] = loss + + action_epoch = recipe.pop('action_epoch') + assert action_epoch not in recipes, \ + f'The `action_epoch` {action_epoch} is repeated ' \ + 'in the SwitchRecipeHook.' + recipes[action_epoch] = recipe + self.schedule = OrderedDict(sorted(recipes.items())) + + def before_train(self, runner) -> None: + """before run setting. If resume form a checkpoint, do all switch + before the current epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + """ + if runner._resume: + for action_epoch, recipe in self.schedule.items(): + if action_epoch >= runner.epoch + 1: + break + self._do_switch(runner, recipe, + f' (resume recipe of epoch {action_epoch})') + + def before_train_epoch(self, runner): + """do before train epoch.""" + recipe = self.schedule.get(runner.epoch + 1, None) + if recipe is not None: + self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}') + + def _do_switch(self, runner, recipe, extra_info=''): + """do the switch aug process.""" + if 'batch_augments' in recipe: + self._switch_batch_augments(runner, recipe['batch_augments']) + runner.logger.info(f'Switch batch augments{extra_info}.') + + if 'train_pipeline' in recipe: + self._switch_train_pipeline(runner, recipe['train_pipeline']) + runner.logger.info(f'Switch train pipeline{extra_info}.') + + if 'loss' in recipe: + self._switch_loss(runner, recipe['loss']) + runner.logger.info(f'Switch loss{extra_info}.') + + @staticmethod + def _switch_batch_augments(runner, batch_augments): + """switch the train augments.""" + model = runner.model + if is_model_wrapper(model): + model = model.module + + model.data_preprocessor.batch_augments = batch_augments + + @staticmethod + def _switch_train_pipeline(runner, train_pipeline): + """switch the train loader dataset pipeline.""" + + def switch_pipeline(dataset, pipeline): + if hasattr(dataset, 'pipeline'): + # for usual dataset + dataset.pipeline = pipeline + elif hasattr(dataset, 'datasets'): + # for concat dataset wrapper + for ds in dataset.datasets: + switch_pipeline(ds, pipeline) + elif hasattr(dataset, 'dataset'): + # for other dataset wrappers + switch_pipeline(dataset.dataset, pipeline) + else: + raise RuntimeError( + 'Cannot access the `pipeline` of the dataset.') + + train_loader = runner.train_loop.dataloader + switch_pipeline(train_loader.dataset, train_pipeline) + + # To restart the iterator of dataloader when `persistent_workers=True` + train_loader._iterator = None + + @staticmethod + def _switch_loss(runner, loss_module): + """switch the loss module.""" + model = runner.model + if is_model_wrapper(model, MODEL_WRAPPERS): + model = model.module + + if hasattr(model, 'loss_module'): + model.loss_module = loss_module + elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'): + model.head.loss_module = loss_module + else: + raise RuntimeError('Cannot access the `loss_module` of the model.') diff --git a/mmpretrain/engine/hooks/visualization_hook.py b/mmpretrain/engine/hooks/visualization_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..64d2230a79db971bef78d77bcf80c40365bddb15 --- /dev/null +++ b/mmpretrain/engine/hooks/visualization_hook.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import os.path as osp +from typing import Optional, Sequence + +from mmengine.fileio import join_path +from mmengine.hooks import Hook +from mmengine.runner import EpochBasedTrainLoop, Runner +from mmengine.visualization import Visualizer + +from mmpretrain.registry import HOOKS +from mmpretrain.structures import DataSample + + +@HOOKS.register_module() +class VisualizationHook(Hook): + """Classification Visualization Hook. Used to visualize validation and + testing prediction results. + + - If ``out_dir`` is specified, all storage backends are ignored + and save the image to the ``out_dir``. + - If ``show`` is True, plot the result image in a window, please + confirm you are able to access the graphical interface. + + Args: + enable (bool): Whether to enable this hook. Defaults to False. + interval (int): The interval of samples to visualize. Defaults to 5000. + show (bool): Whether to display the drawn image. Defaults to False. + out_dir (str, optional): directory where painted images will be saved + in the testing process. If None, handle with the backends of the + visualizer. Defaults to None. + **kwargs: other keyword arguments of + :meth:`mmpretrain.visualization.UniversalVisualizer.visualize_cls`. + """ + + def __init__(self, + enable=False, + interval: int = 5000, + show: bool = False, + out_dir: Optional[str] = None, + **kwargs): + self._visualizer: Visualizer = Visualizer.get_current_instance() + + self.enable = enable + self.interval = interval + self.show = show + self.out_dir = out_dir + + self.draw_args = {**kwargs, 'show': show} + + def _draw_samples(self, + batch_idx: int, + data_batch: dict, + data_samples: Sequence[DataSample], + step: int = 0) -> None: + """Visualize every ``self.interval`` samples from a data batch. + + Args: + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DataSample`]): Outputs from model. + step (int): Global step value to record. Defaults to 0. + """ + if self.enable is False: + return + + batch_size = len(data_samples) + images = data_batch['inputs'] + start_idx = batch_size * batch_idx + end_idx = start_idx + batch_size + + # The first index divisible by the interval, after the start index + first_sample_id = math.ceil(start_idx / self.interval) * self.interval + + for sample_id in range(first_sample_id, end_idx, self.interval): + image = images[sample_id - start_idx] + image = image.permute(1, 2, 0).cpu().numpy().astype('uint8') + + data_sample = data_samples[sample_id - start_idx] + if 'img_path' in data_sample: + # osp.basename works on different platforms even file clients. + sample_name = osp.basename(data_sample.get('img_path')) + else: + sample_name = str(sample_id) + + draw_args = self.draw_args + if self.out_dir is not None: + draw_args['out_file'] = join_path(self.out_dir, + f'{sample_name}_{step}.png') + + self._visualizer.visualize_cls( + image=image, + data_sample=data_sample, + step=step, + name=sample_name, + **self.draw_args, + ) + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DataSample]) -> None: + """Visualize every ``self.interval`` samples during validation. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DataSample`]): Outputs from model. + """ + if isinstance(runner.train_loop, EpochBasedTrainLoop): + step = runner.epoch + else: + step = runner.iter + + self._draw_samples(batch_idx, data_batch, outputs, step=step) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DataSample]) -> None: + """Visualize every ``self.interval`` samples during test. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DetDataSample`]): Outputs from model. + """ + self._draw_samples(batch_idx, data_batch, outputs, step=0) diff --git a/mmpretrain/engine/hooks/warmup_param_hook.py b/mmpretrain/engine/hooks/warmup_param_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..b45d8918dbbcb9cf5d12c252621908f0b6c1f251 --- /dev/null +++ b/mmpretrain/engine/hooks/warmup_param_hook.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import operator as op +from typing import Any, Optional, Union + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class WarmupParamHook(Hook): + """This is a hook used for changing the parameters other than optimizations + that need to warmup inside the module. + + This hook can extend with more detailed warmup rule if necessary. + + Args: + param_name (str): The parameter name that needs to be altered. + module_name (str): Module name that belongs to the model. Such as + `head`, `head.loss`, etc. + warmup_epochs (int): The warmup epochs for this parameter. + """ + + def __init__( + self, + param_name: str, + module_name: str, + warmup_epochs: int, + ) -> None: + self.param_name = param_name + self.warmup_epochs = warmup_epochs + # getter for module which saves the changed parameter + self.module_getter = op.attrgetter(module_name) + + def get_param(self, runner) -> Any: + """Get the parameter.""" + try: + module = self.module_getter(get_ori_model(runner.model)) + return getattr(module, self.param_name) + except AttributeError as e: + raise AttributeError(f'{e}. Please check hook settings.') + + def set_param(self, runner, value) -> None: + """Set the parameter.""" + try: + module = self.module_getter(get_ori_model(runner.model)) + setattr(module, self.param_name, value) + except AttributeError as e: + raise AttributeError(f'{e}. Please check hook settings.') + + def before_train(self, runner) -> None: + """Get the original value before train.""" + self.ori_val = self.get_param(runner) + + def before_train_iter( + self, + runner, + batch_idx: int, + data_batch: Optional[Union[dict, tuple, list]] = None) -> None: + """Set the warmup value before each train iter.""" + cur_iter = runner.iter + iters_per_epoch = runner.max_iters / runner.max_epochs + new_val = self.ori_val * min( + 1, cur_iter / (self.warmup_epochs * iters_per_epoch)) + self.set_param(runner, new_val) diff --git a/mmpretrain/engine/optimizers/__init__.py b/mmpretrain/engine/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd53a37630b2a0dfbb69b1020518b9ec4ff03715 --- /dev/null +++ b/mmpretrain/engine/optimizers/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .adan_t import Adan +from .lamb import Lamb +from .lars import LARS +from .layer_decay_optim_wrapper_constructor import \ + LearningRateDecayOptimWrapperConstructor + +__all__ = ['Lamb', 'Adan', 'LARS', 'LearningRateDecayOptimWrapperConstructor'] diff --git a/mmpretrain/engine/optimizers/adan_t.py b/mmpretrain/engine/optimizers/adan_t.py new file mode 100644 index 0000000000000000000000000000000000000000..571a71b6fe561fb33053af2fd6d2161a775918e4 --- /dev/null +++ b/mmpretrain/engine/optimizers/adan_t.py @@ -0,0 +1,312 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + +from mmpretrain.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module() +class Adan(Optimizer): + """Implements a pytorch variant of Adan. + + Adan was proposed in + Adan : Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models. # noqa + https://arxiv.org/abs/2208.06677 + Arguments: + params (iterable): iterable of parameters to optimize + or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float, flot], optional): coefficients used + for computing running averages of gradient. + (default: (0.98, 0.92, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): decoupled weight decay + (L2 penalty) (default: 0) + max_grad_norm (float, optional): value used to clip + global grad norm (default: 0.0 no clip) + no_prox (bool): how to perform the decoupled weight decay + (default: False) + foreach (bool): if True would use torch._foreach implementation. + It's faster but uses slightly more memory. + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.98, 0.92, 0.99), + eps=1e-8, + weight_decay=0.0, + max_grad_norm=0.0, + no_prox=False, + foreach: bool = True): + if not 0.0 <= max_grad_norm: + raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm)) + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= eps: + raise ValueError('Invalid epsilon value: {}'.format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format( + betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format( + betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError('Invalid beta parameter at index 2: {}'.format( + betas[2])) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + no_prox=no_prox, + foreach=foreach) + super().__init__(params, defaults) + + def __setstate__(self, state): + super(Adan, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('no_prox', False) + + @torch.no_grad() + def restart_opt(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + if p.requires_grad: + state = self.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self): + """Performs a single optimization step.""" + if self.defaults['max_grad_norm'] > 0: + device = self.param_groups[0]['params'][0].device + global_grad_norm = torch.zeros(1, device=device) + + max_grad_norm = torch.tensor( + self.defaults['max_grad_norm'], device=device) + for group in self.param_groups: + + for p in group['params']: + if p.grad is not None: + grad = p.grad + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + group['eps'] + + clip_global_grad_norm = \ + torch.clamp(max_grad_norm / global_grad_norm, max=1.0) + else: + clip_global_grad_norm = 1.0 + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + exp_avg_diffs = [] + pre_grads = [] + + beta1, beta2, beta3 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support + # by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1 = 1.0 - beta1**group['step'] + bias_correction2 = 1.0 - beta2**group['step'] + bias_correction3 = 1.0 - beta3**group['step'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['exp_avg_diff'] = torch.zeros_like(p) + + if 'pre_grad' not in state or group['step'] == 1: + # at first step grad wouldn't be clipped + # by `clip_global_grad_norm` + # this is only to simplify implementation + state['pre_grad'] = p.grad + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + exp_avg_diffs.append(state['exp_avg_diff']) + pre_grads.append(state['pre_grad']) + + kwargs = dict( + params=params_with_grad, + grads=grads, + exp_avgs=exp_avgs, + exp_avg_sqs=exp_avg_sqs, + exp_avg_diffs=exp_avg_diffs, + pre_grads=pre_grads, + beta1=beta1, + beta2=beta2, + beta3=beta3, + bias_correction1=bias_correction1, + bias_correction2=bias_correction2, + bias_correction3_sqrt=math.sqrt(bias_correction3), + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + no_prox=group['no_prox'], + clip_global_grad_norm=clip_global_grad_norm, + ) + if group['foreach']: + copy_grads = _multi_tensor_adan(**kwargs) + else: + copy_grads = _single_tensor_adan(**kwargs) + + for p, copy_grad in zip(params_with_grad, copy_grads): + self.state[p]['pre_grad'] = copy_grad + + +def _single_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + copy_grads = [] + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg_diff = exp_avg_diffs[i] + pre_grad = pre_grads[i] + + grad = grad.mul_(clip_global_grad_norm) + copy_grads.append(grad.clone()) + + diff = grad - pre_grad + update = grad + beta2 * diff + + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t + exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2) # diff_t + exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3) # n_t + + denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps) + update = exp_avg / bias_correction1 + update.add_(beta2 * exp_avg_diff / bias_correction2).div_(denom) + + if no_prox: + param.mul_(1 - lr * weight_decay) + param.add_(update, alpha=-lr) + else: + param.add_(update, alpha=-lr) + param.div_(1 + lr * weight_decay) + return copy_grads + + +def _multi_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + if clip_global_grad_norm < 1.0: + torch._foreach_mul_(grads, clip_global_grad_norm.item()) + copy_grads = [g.clone() for g in grads] + + diff = torch._foreach_sub(grads, pre_grads) + # NOTE: line below while looking identical gives different result, + # due to float precision errors. + # using mul+add produces identical results to single-tensor, + # using add+alpha doesn't + # update = torch._foreach_add(grads, torch._foreach_mul(diff, beta2)) + update = torch._foreach_add(grads, diff, alpha=beta2) + + torch._foreach_mul_(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t + + torch._foreach_mul_(exp_avg_diffs, beta2) + torch._foreach_add_(exp_avg_diffs, diff, alpha=1 - beta2) # diff_t + + torch._foreach_mul_(exp_avg_sqs, beta3) + torch._foreach_addcmul_( + exp_avg_sqs, update, update, value=1 - beta3) # n_t + + denom = torch._foreach_sqrt(exp_avg_sqs) + torch._foreach_div_(denom, bias_correction3_sqrt) + torch._foreach_add_(denom, eps) + + update = torch._foreach_div(exp_avgs, bias_correction1) + # NOTE: same issue as above. + # beta2 * diff / bias_correction2 != diff * (beta2 / bias_correction2) # noqa + # using faster version by default. uncomment for tests to pass + # torch._foreach_add_(update, torch._foreach_div(torch._foreach_mul(exp_avg_diffs, beta2), bias_correction2)) # noqa + torch._foreach_add_( + update, torch._foreach_mul(exp_avg_diffs, beta2 / bias_correction2)) + torch._foreach_div_(update, denom) + + if no_prox: + torch._foreach_mul_(params, 1 - lr * weight_decay) + else: + torch._foreach_add_(params, update, alpha=-lr) + torch._foreach_div_(params, 1 + lr * weight_decay) + return copy_grads diff --git a/mmpretrain/engine/optimizers/lamb.py b/mmpretrain/engine/optimizers/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..0b44a1c168e03fa7f569388beec206fe68c64749 --- /dev/null +++ b/mmpretrain/engine/optimizers/lamb.py @@ -0,0 +1,228 @@ +"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb. + +This optimizer code was adapted from the following (starting with latest) +* https://github.com/HabanaAI/Model-References/blob/ +2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/ +LanguageModeling/Transformer-XL/pytorch/lamb.py +* https://github.com/cybertronai/pytorch-lamb + +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb +is to have a version that is +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or +cannot install/use APEX. + +In addition to some cleanup, this Lamb impl has been modified to support +PyTorch XLA and has been tested on TPU. + +Original copyrights for above sources are below. + +Modifications Copyright 2021 Ross Wightman +""" +# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. + +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math + +import torch +from torch.optim import Optimizer + +from mmpretrain.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module() +class Lamb(Optimizer): + """A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer. + + This class is copied from `timm`_. The LAMB was proposed in `Large Batch + Optimization for Deep Learning - Training BERT in 76 minutes`_. + + .. _timm: + https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm + (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + """ # noqa: E501 + + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + grad_averaging=True, + max_grad_norm=1.0, + trust_clip=False, + always_adapt=False): + defaults = dict( + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm, + trust_clip=trust_clip, + always_adapt=always_adapt) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor( + 1.0, device=device + ) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + 'Lamb does not support sparse gradients, consider ' + 'SparseAdam instead.') + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars + # when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor( + self.defaults['max_grad_norm'], device=device) + clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor) + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or + # pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1**group['step'] + bias_correction2 = 1 - beta2**group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group['eps']) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # Layer-wise LR adaptation. By default, skip adaptation on + # parameters that are + # excluded from weight decay, unless always_adapt == True, + # then always enabled. + w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not + # working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group['trust_clip']: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group['lr']) + + return loss diff --git a/mmpretrain/engine/optimizers/lars.py b/mmpretrain/engine/optimizers/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..5e388878374e3d1e7408861a5f1830b00df5664b --- /dev/null +++ b/mmpretrain/engine/optimizers/lars.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterable + +import torch +from torch.optim.optimizer import Optimizer + +from mmpretrain.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module() +class LARS(Optimizer): + """Implements layer-wise adaptive rate scaling for SGD. + + Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. + `Large Batch Training of Convolutional Networks: + `_. + + Args: + params (Iterable): Iterable of parameters to optimize or dicts defining + parameter groups. + lr (float): Base learning rate. + momentum (float): Momentum factor. Defaults to 0. + weight_decay (float): Weight decay (L2 penalty). Defaults to 0. + dampening (float): Dampening for momentum. Defaults to 0. + eta (float): LARS coefficient. Defaults to 0.001. + nesterov (bool): Enables Nesterov momentum. Defaults to False. + eps (float): A small number to avoid dviding zero. Defaults to 1e-8. + + Example: + >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, + >>> weight_decay=1e-4, eta=1e-3) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + """ + + def __init__(self, + params: Iterable, + lr: float, + momentum: float = 0, + weight_decay: float = 0, + dampening: float = 0, + eta: float = 0.001, + nesterov: bool = False, + eps: float = 1e-8) -> None: + if not isinstance(lr, float) and lr < 0.0: + raise ValueError(f'Invalid learning rate: {lr}') + if momentum < 0.0: + raise ValueError(f'Invalid momentum value: {momentum}') + if weight_decay < 0.0: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + if eta < 0.0: + raise ValueError(f'Invalid LARS coefficient value: {eta}') + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + eta=eta) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError( + 'Nesterov momentum requires a momentum and zero dampening') + + self.eps = eps + super().__init__(params, defaults) + + def __setstate__(self, state) -> None: + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() + def step(self, closure=None) -> torch.Tensor: + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + eta = group['eta'] + nesterov = group['nesterov'] + lr = group['lr'] + lars_exclude = group.get('lars_exclude', False) + + for p in group['params']: + if p.grad is None: + continue + + d_p = p.grad + + if lars_exclude: + local_lr = 1. + else: + weight_norm = torch.norm(p).item() + grad_norm = torch.norm(d_p).item() + if weight_norm != 0 and grad_norm != 0: + # Compute local learning rate for this layer + local_lr = eta * weight_norm / \ + (grad_norm + weight_decay * weight_norm + self.eps) + else: + local_lr = 1. + + actual_lr = local_lr * lr + d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr) + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = \ + torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + p.add_(-d_p) + + return loss diff --git a/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py b/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..09c6abc54a9f49cc789bf91d2bf74b0ec68902c4 --- /dev/null +++ b/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from typing import Callable, List, Optional + +from mmengine.logging import MMLogger +from mmengine.optim import DefaultOptimWrapperConstructor +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm +from torch import nn +from torch.nn import GroupNorm, LayerNorm + +from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): + """Different learning rates are set for different layers of backbone. + + By default, each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain the following fields: + + - ``layer_decay_rate`` (float): The learning rate of a parameter will + multiply it by multiple times according to the layer depth of the + parameter. Usually, it's less than 1, so that the earlier layers will + have a lower learning rate. Defaults to 1. + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in normalization layers). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization layers. + - ``flat_decay_mult`` (float): It will be multiplied to the weight + decay for all one-dimensional parameters + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_decay_mult`` will be + ignored. It should be a dict and may contain fields ``decay_mult``. + (The ``lr_mult`` is disabled in this constructor). + + Example: + + In the config file, you can use this constructor as below: + + .. code:: python + + optim_wrapper = dict( + optimizer=dict( + type='AdamW', + lr=4e-3, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + constructor='LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + layer_decay_rate=0.75, # layer-wise lr decay factor + norm_decay_mult=0., + flat_decay_mult=0., + custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + })) + """ + + def add_params(self, + params: List[dict], + module: nn.Module, + prefix: str = '', + get_layer_depth: Optional[Callable] = None, + **kwargs) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (List[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + optimizer_cfg (dict): The configuration of optimizer. + prefix (str): The prefix of the module. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + logger = MMLogger.get_current_instance() + + # The model should have `get_layer_depth` method + if get_layer_depth is None and not hasattr(module, 'get_layer_depth'): + raise NotImplementedError('The layer-wise learning rate decay need' + f' the model {type(module)} has' + ' `get_layer_depth` method.') + else: + get_layer_depth = get_layer_depth or module.get_layer_depth + + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None) + flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None) + decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + param_name = prefix + name + if not param.requires_grad: + continue + + if self.base_wd is not None: + base_wd = self.base_wd + custom_key = next( + filter(lambda k: k in param_name, sorted_keys), None) + # custom parameters decay + if custom_key is not None: + custom_cfg = custom_keys[custom_key].copy() + decay_mult = custom_cfg.pop('decay_mult', 1.) + + param_group['weight_decay'] = base_wd * decay_mult + # add custom settings to param_group + param_group.update(custom_cfg) + # norm decay + elif is_norm and norm_decay_mult is not None: + param_group['weight_decay'] = base_wd * norm_decay_mult + # bias decay + elif name == 'bias' and bias_decay_mult is not None: + param_group['weight_decay'] = base_wd * bias_decay_mult + # flatten parameters decay + elif param.ndim == 1 and flat_decay_mult is not None: + param_group['weight_decay'] = base_wd * flat_decay_mult + else: + param_group['weight_decay'] = base_wd + + layer_id, max_id = get_layer_depth(param_name) + scale = decay_rate**(max_id - layer_id - 1) + param_group['lr'] = self.base_lr * scale + param_group['lr_scale'] = scale + param_group['layer_id'] = layer_id + param_group['param_name'] = param_name + + params.append(param_group) + + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}{child_name}.' + self.add_params( + params, + child_mod, + prefix=child_prefix, + get_layer_depth=get_layer_depth, + ) + + if prefix == '': + layer_params = defaultdict(list) + for param in params: + layer_params[param['layer_id']].append(param) + for layer_id, layer_params in layer_params.items(): + lr_scale = layer_params[0]['lr_scale'] + lr = layer_params[0]['lr'] + msg = [ + f'layer {layer_id} params ' + f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):' + ] + for param in layer_params: + msg.append(f'\t{param["param_name"]}: ' + f'weight_decay={param["weight_decay"]:.3g}') + logger.debug('\n'.join(msg)) diff --git a/mmpretrain/engine/runners/__init__.py b/mmpretrain/engine/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23206e1ea7c83fa1d547c677b3fe5203f8c5485f --- /dev/null +++ b/mmpretrain/engine/runners/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .retrieval_loop import RetrievalTestLoop, RetrievalValLoop + +__all__ = ['RetrievalTestLoop', 'RetrievalValLoop'] diff --git a/mmpretrain/engine/runners/retrieval_loop.py b/mmpretrain/engine/runners/retrieval_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..d15387eddeb9075c23949f95a77ed59006bb9a38 --- /dev/null +++ b/mmpretrain/engine/runners/retrieval_loop.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +from mmengine.model import is_model_wrapper +from mmengine.runner import TestLoop, ValLoop, autocast + +from mmpretrain.registry import LOOPS + + +@LOOPS.register_module() +class RetrievalValLoop(ValLoop): + """Loop for multimodal retrieval val. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 valing. Defaults to + False. + """ + + def run(self) -> dict: + """Launch val.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + + feats_local = [] + data_samples_local = [] + + for idx, data_batch in enumerate(self.dataloader): + with torch.no_grad(): + self.runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + if is_model_wrapper(self.runner.model): + data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501 + else: + data_preprocessor = self.runner.model.data_preprocessor + + # get features for retrieval instead of data samples + data_batch = data_preprocessor(data_batch, False) + feats = self.runner.model._run_forward( + data_batch, mode='tensor') + feats_local.append(feats) + data_samples_local.extend(data_batch['data_samples']) + self.runner.call_hook( + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=feats) + + # concatenate different features + feats_local = { + k: torch.cat([dic[k] for dic in feats_local]) + for k in feats_local[0] + } + + # get predictions + if is_model_wrapper(self.runner.model): + predict_all_fn = self.runner.model.module.predict_all + else: + predict_all_fn = self.runner.model.predict_all + + img_size = self.dataloader.dataset.img_size + text_size = self.dataloader.dataset.text_size + with torch.no_grad(): + i2t_data_samples, t2i_data_samples = predict_all_fn( + feats_local, + data_samples_local, + num_images=img_size, + num_texts=text_size, + ) + + # process in evaluator and compute metrics + self.evaluator.process(i2t_data_samples, None) + i2t_metrics = self.evaluator.evaluate(img_size) + i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()} + self.evaluator.process(t2i_data_samples, None) + t2i_metrics = self.evaluator.evaluate(text_size) + t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()} + metrics = {**i2t_metrics, **t2i_metrics} + + self.runner.call_hook('after_val_epoch', metrics=metrics) + self.runner.call_hook('after_val') + return metrics + + +@LOOPS.register_module() +class RetrievalTestLoop(TestLoop): + """Loop for multimodal retrieval test. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 testing. Defaults to + False. + """ + + def run(self) -> dict: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + self.runner.model.eval() + + feats_local = [] + data_samples_local = [] + + for idx, data_batch in enumerate(self.dataloader): + with torch.no_grad(): + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + if is_model_wrapper(self.runner.model): + data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501 + else: + data_preprocessor = self.runner.model.data_preprocessor + # get features for retrieval instead of data samples + data_batch = data_preprocessor(data_batch, False) + feats = self.runner.model._run_forward( + data_batch, mode='tensor') + feats_local.append(feats) + data_samples_local.extend(data_batch['data_samples']) + self.runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=feats) + + # concatenate different features + feats_local = { + k: torch.cat([dic[k] for dic in feats_local]) + for k in feats_local[0] + } + + # get predictions + if is_model_wrapper(self.runner.model): + predict_all_fn = self.runner.model.module.predict_all + else: + predict_all_fn = self.runner.model.predict_all + + img_size = self.dataloader.dataset.img_size + text_size = self.dataloader.dataset.text_size + with torch.no_grad(): + i2t_data_samples, t2i_data_samples = predict_all_fn( + feats_local, + data_samples_local, + num_images=img_size, + num_texts=text_size, + ) + + # process in evaluator and compute metrics + self.evaluator.process(i2t_data_samples, None) + i2t_metrics = self.evaluator.evaluate(img_size) + i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()} + self.evaluator.process(t2i_data_samples, None) + t2i_metrics = self.evaluator.evaluate(text_size) + t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()} + metrics = {**i2t_metrics, **t2i_metrics} + + self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test') + return metrics diff --git a/mmpretrain/engine/schedulers/__init__.py b/mmpretrain/engine/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68b6a5477b84a53e060e0e6d43fdac830adebffb --- /dev/null +++ b/mmpretrain/engine/schedulers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .weight_decay_scheduler import CosineAnnealingWeightDecay + +__all__ = ['CosineAnnealingWeightDecay'] diff --git a/mmpretrain/engine/schedulers/weight_decay_scheduler.py b/mmpretrain/engine/schedulers/weight_decay_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..7e725a4c3f53856cf848ed7e6a225a178b36ab98 --- /dev/null +++ b/mmpretrain/engine/schedulers/weight_decay_scheduler.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmengine.optim.scheduler import CosineAnnealingParamScheduler + +from mmpretrain.registry import PARAM_SCHEDULERS + + +class WeightDecaySchedulerMixin: + """A mixin class for learning rate schedulers.""" + + def __init__(self, optimizer, *args, **kwargs): + super().__init__(optimizer, 'weight_decay', *args, **kwargs) + + +@PARAM_SCHEDULERS.register_module() +class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin, + CosineAnnealingParamScheduler): + """Set the weight decay value of each parameter group using a cosine + annealing schedule. + + If the weight decay was set to be 0 initially, the weight decay value will + be 0 constantly during the training. + """ + + def _get_value(self) -> list: + """Compute value using chainable form of the scheduler.""" + + def _get_eta_min(base_value): + if self.eta_min_ratio is None: + return self.eta_min + return base_value * self.eta_min_ratio + + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0: + weight_decay_value_list = [] + for base_value, group in zip(self.base_values, + self.optimizer.param_groups): + if base_value == 0: + group_value = 0 + else: + group_value = group[self.param_name] + ( + base_value - _get_eta_min(base_value)) * ( + 1 - math.cos(math.pi / self.T_max)) / 2 + weight_decay_value_list.append(group_value) + return weight_decay_value_list + + weight_decay_value_list = [] + for base_value, group in zip(self.base_values, + self.optimizer.param_groups): + if base_value == 0: + group_value = 0 + else: + group_value = ( + 1 + math.cos(math.pi * self.last_step / self.T_max)) / ( + 1 + math.cos(math.pi * + (self.last_step - 1) / self.T_max) + ) * (group[self.param_name] - + _get_eta_min(base_value)) + _get_eta_min(base_value) + weight_decay_value_list.append(group_value) + return weight_decay_value_list diff --git a/mmpretrain/evaluation/__init__.py b/mmpretrain/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f70dc226d30f7b8e4ee5a44ca163ad1ae04eabf5 --- /dev/null +++ b/mmpretrain/evaluation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .functional import * # noqa: F401,F403 +from .metrics import * # noqa: F401,F403 diff --git a/mmpretrain/evaluation/__pycache__/__init__.cpython-38.pyc b/mmpretrain/evaluation/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aba2f6036747eea59a71e2f97b0d349c7efe865 Binary files /dev/null and b/mmpretrain/evaluation/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/functional/__init__.py b/mmpretrain/evaluation/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef101fec61e72abc0eb90266d453b5b22331378d --- /dev/null +++ b/mmpretrain/evaluation/functional/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/evaluation/functional/__pycache__/__init__.cpython-38.pyc b/mmpretrain/evaluation/functional/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3e015862669f76e47347e9e68707daf29dd4316 Binary files /dev/null and b/mmpretrain/evaluation/functional/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd46de7833ff5c2c5fb58a077eee8785710ff37c --- /dev/null +++ b/mmpretrain/evaluation/metrics/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .caption import COCOCaption +from .gqa import GQAAcc +from .multi_label import AveragePrecision, MultiLabelMetric +from .multi_task import MultiTasksMetric +from .nocaps import NocapsSave +from .retrieval import RetrievalAveragePrecision, RetrievalRecall +from .scienceqa import ScienceQAMetric +from .shape_bias_label import ShapeBiasMetric +from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric +from .visual_grounding_eval import VisualGroundingMetric +from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric +from .vqa import ReportVQA, VQAAcc + +__all__ = [ + 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision', + 'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric', + 'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption', + 'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave', + 'RetrievalAveragePrecision', 'ShapeBiasMetric' +] diff --git a/mmpretrain/evaluation/metrics/__pycache__/__init__.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9db958ec9aaf01fab15a49fcb4deec9129e2e76 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/caption.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/caption.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a173efbe94c66d7d00aed22d2193a7f2345accde Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/caption.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/gqa.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/gqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90a52a6dbf1f822d8588ca36acd97d01c164b825 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/gqa.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/multi_label.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/multi_label.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ef486926d8c574879dc2157742fbef4f490788d Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/multi_label.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/multi_task.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/multi_task.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c60058032d33be8d9a1a3dc9cd9221d1e21be4f8 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/multi_task.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/nocaps.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/nocaps.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ec8b683f984d02ca339b0aec9bfd63133fdd258 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/nocaps.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/retrieval.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/retrieval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb731b44aae65616d9f1c859b41ec8a7063702ea Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/retrieval.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/scienceqa.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/scienceqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03603ff44a999a0b7d7b332d807ef82c4a87be56 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/scienceqa.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/shape_bias_label.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/shape_bias_label.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5712509141dbd4b1b7b98dfd7635fb7c30d3be74 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/shape_bias_label.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/single_label.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/single_label.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dc181c42f1c2c8be0c81f6045373096ac9d5188 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/single_label.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/visual_grounding_eval.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/visual_grounding_eval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d47e10ac5aaf94bd0051db1d1e0773f75c27319 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/visual_grounding_eval.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/voc_multi_label.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/voc_multi_label.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b49bb1c97051c449328bae888267596ca57cc75 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/voc_multi_label.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/vqa.cpython-38.pyc b/mmpretrain/evaluation/metrics/__pycache__/vqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19fc0f5e2fbef1a99af875be6553634ce79f1226 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/vqa.cpython-38.pyc differ diff --git a/mmpretrain/evaluation/metrics/caption.py b/mmpretrain/evaluation/metrics/caption.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bffabfa97a9c6faec7ecc0ffb6d9ba2f435b97 --- /dev/null +++ b/mmpretrain/evaluation/metrics/caption.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os +import tempfile +from typing import List, Optional + +from mmengine.evaluator import BaseMetric +from mmengine.utils import track_iter_progress + +from mmpretrain.registry import METRICS +from mmpretrain.utils import require + +try: + from pycocoevalcap.eval import COCOEvalCap + from pycocotools.coco import COCO +except ImportError: + COCOEvalCap = None + COCO = None + + +@METRICS.register_module() +class COCOCaption(BaseMetric): + """Coco Caption evaluation wrapper. + + Save the generated captions and transform into coco format. + Calling COCO API for caption metrics. + + Args: + ann_file (str): the path for the COCO format caption ground truth + json file, load for evaluations. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + + @require('pycocoevalcap') + def __init__(self, + ann_file: str, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.ann_file = ann_file + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + + result['caption'] = data_sample.get('pred_caption') + result['image_id'] = int(data_sample.get('image_id')) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. + + with tempfile.TemporaryDirectory() as temp_dir: + + eval_result_file = save_result( + result=results, + result_dir=temp_dir, + filename='m4-caption_pred', + remove_duplicate='image_id', + ) + + coco_val = coco_caption_eval(eval_result_file, self.ann_file) + + return coco_val + + +def save_result(result, result_dir, filename, remove_duplicate=''): + """Saving predictions as json file for evaluation.""" + + # combine results from all processes + result_new = [] + + if remove_duplicate: + result_new = [] + id_list = [] + for res in track_iter_progress(result): + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + final_result_file_url = os.path.join(result_dir, '%s.json' % filename) + print(f'result file saved to {final_result_file_url}') + json.dump(result, open(final_result_file_url, 'w')) + + return final_result_file_url + + +def coco_caption_eval(results_file, ann_file): + """Evaluation between gt json and prediction json files.""" + # create coco object and coco_result object + coco = COCO(ann_file) + coco_result = coco.loadRes(results_file) + + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + + # make sure the image ids are the same + coco_eval.params['image_id'] = coco_result.getImgIds() + + # This will take some times at the first run + coco_eval.evaluate() + + # print output evaluation scores + for metric, score in coco_eval.eval.items(): + print(f'{metric}: {score:.3f}') + + return coco_eval.eval diff --git a/mmpretrain/evaluation/metrics/gqa.py b/mmpretrain/evaluation/metrics/gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e8b0725524839c5b0a15a8ba6fb4eed689e589 --- /dev/null +++ b/mmpretrain/evaluation/metrics/gqa.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from mmengine.evaluator import BaseMetric + +from mmpretrain.evaluation.metrics.vqa import (_process_digit_article, + _process_punctuation) +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class GQAAcc(BaseMetric): + """GQA Acc metric. + + Compute GQA accuracy. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'GQA' + + def __init__(self, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch, data_samples) -> None: + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer + } + + self.results.append(result) + + def compute_metrics(self, results: List) -> dict: + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + acc = [] + for result in results: + pred_answer = self._process_answer(result['pred_answer']) + gt_answer = self._process_answer(result['gt_answer']) + gqa_acc = 1 if pred_answer == gt_answer else 0 + acc.append(gqa_acc) + + accuracy = sum(acc) / len(acc) + + metrics = {'acc': accuracy} + return metrics + + def _process_answer(self, answer) -> str: + answer = _process_punctuation(answer) + answer = _process_digit_article(answer) + return answer diff --git a/mmpretrain/evaluation/metrics/multi_label.py b/mmpretrain/evaluation/metrics/multi_label.py new file mode 100644 index 0000000000000000000000000000000000000000..bd91aac4449c845fbed514ed5f800bd971236ade --- /dev/null +++ b/mmpretrain/evaluation/metrics/multi_label.py @@ -0,0 +1,599 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .single_label import _precision_recall_f1_support, to_tensor + + +@METRICS.register_module() +class MultiLabelMetric(BaseMetric): + r"""A collection of precision, recall, f1-score and support for + multi-label tasks. + + The collection of metrics is for single-label multi-class classification. + And all these metrics are based on the confusion matrix of every category: + + .. image:: ../../_static/image/confusion-matrix.png + :width: 60% + :align: center + + All metrics can be formulated use variables above: + + **Precision** is the fraction of correct predictions in all predictions: + + .. math:: + \text{Precision} = \frac{TP}{TP+FP} + + **Recall** is the fraction of correct predictions in all targets: + + .. math:: + \text{Recall} = \frac{TP}{TP+FN} + + **F1-score** is the harmonic mean of the precision and recall: + + .. math:: + \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}} + + **Support** is the number of samples: + + .. math:: + \text{Support} = TP + TN + FN + FP + + Args: + thr (float, optional): Predictions with scores under the threshold + are considered as negative. If None, the ``topk`` predictions will + be considered as positive. If the ``topk`` is also None, use + ``thr=0.5`` as default. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. If None, use ``thr`` to determine positive + predictions. If both ``thr`` and ``topk`` are not None, use + ``thr``. Defaults to None. + items (Sequence[str]): The detailed metric items to evaluate, select + from "precision", "recall", "f1-score" and "support". + Defaults to ``('precision', 'recall', 'f1-score')``. + average (str | None): How to calculate the final metrics from the + confusion matrix of every category. It supports three modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories and + calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import MultiLabelMetric + >>> # ------ The Basic Usage for category indices labels ------- + >>> y_pred = [[0], [1], [0, 1], [3]] + >>> y_true = [[0, 3], [0, 2], [1], [3]] + >>> # Output precision, recall, f1-score and support + >>> MultiLabelMetric.calculate( + ... y_pred, y_true, pred_indices=True, target_indices=True, num_classes=4) + (tensor(50.), tensor(50.), tensor(45.8333), tensor(6)) + >>> # ----------- The Basic Usage for one-hot labels ----------- + >>> y_pred = torch.tensor([[1, 1, 0, 0], + ... [1, 1, 0, 0], + ... [0, 0, 1, 0], + ... [0, 1, 0, 0], + ... [0, 1, 0, 0]]) + >>> y_true = torch.Tensor([[1, 1, 0, 0], + ... [0, 0, 1, 0], + ... [1, 1, 1, 0], + ... [1, 0, 0, 0], + ... [1, 0, 0, 0]]) + >>> MultiLabelMetric.calculate(y_pred, y_true) + (tensor(43.7500), tensor(31.2500), tensor(33.3333), tensor(8)) + >>> # --------- The Basic Usage for one-hot pred scores --------- + >>> y_pred = torch.rand(y_true.size()) + >>> y_pred + tensor([[0.4575, 0.7335, 0.3934, 0.2572], + [0.1318, 0.1004, 0.8248, 0.6448], + [0.8349, 0.6294, 0.7896, 0.2061], + [0.4037, 0.7308, 0.6713, 0.8374], + [0.3779, 0.4836, 0.0313, 0.0067]]) + >>> # Calculate with different threshold. + >>> MultiLabelMetric.calculate(y_pred, y_true, thr=0.1) + (tensor(42.5000), tensor(75.), tensor(53.1746), tensor(8)) + >>> # Calculate with topk. + >>> MultiLabelMetric.calculate(y_pred, y_true, topk=1) + (tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8)) + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_sampels = [ + ... DataSample().set_pred_score(pred).set_gt_score(gt) + ... for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))] + >>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5)) + >>> evaluator.process(data_sampels) + >>> evaluator.evaluate(1000) + { + 'multi-label/precision': 50.72898037055408, + 'multi-label/recall': 50.06836461357571, + 'multi-label/f1-score': 50.384466955258475 + } + >>> # Evaluate on each class by using topk strategy + >>> evaluator = Evaluator(metrics=MultiLabelMetric(topk=1, average=None)) + >>> evaluator.process(data_sampels) + >>> evaluator.evaluate(1000) + { + 'multi-label/precision_top1_classwise': [48.22, 50.54, 50.99, 44.18, 52.5], + 'multi-label/recall_top1_classwise': [18.92, 19.22, 19.92, 20.0, 20.27], + 'multi-label/f1-score_top1_classwise': [27.18, 27.85, 28.65, 27.54, 29.25] + } + """ # noqa: E501 + default_prefix: Optional[str] = 'multi-label' + + def __init__(self, + thr: Optional[float] = None, + topk: Optional[int] = None, + items: Sequence[str] = ('precision', 'recall', 'f1-score'), + average: Optional[str] = 'macro', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + + logger = MMLogger.get_current_instance() + if thr is None and topk is None: + thr = 0.5 + logger.warning('Neither thr nor k is given, set thr as 0.5 by ' + 'default.') + elif thr is not None and topk is not None: + logger.warning('Both thr and topk are given, ' + 'use threshold in favor of top-k.') + + self.thr = thr + self.topk = topk + self.average = average + + for item in items: + assert item in ['precision', 'recall', 'f1-score', 'support'], \ + f'The metric {item} is not supported by `SingleLabelMetric`,' \ + ' please choose from "precision", "recall", "f1-score" and ' \ + '"support".' + self.items = tuple(items) + + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + + result['pred_score'] = data_sample['pred_score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'gt_score' in data_sample: + result['gt_score'] = data_sample['gt_score'].clone() + else: + result['gt_score'] = label_to_onehot(data_sample['gt_label'], + num_classes) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. + metrics = {} + + target = torch.stack([res['gt_score'] for res in results]) + pred = torch.stack([res['pred_score'] for res in results]) + + metric_res = self.calculate( + pred, + target, + pred_indices=False, + target_indices=False, + average=self.average, + thr=self.thr, + topk=self.topk) + + def pack_results(precision, recall, f1_score, support): + single_metrics = {} + if 'precision' in self.items: + single_metrics['precision'] = precision + if 'recall' in self.items: + single_metrics['recall'] = recall + if 'f1-score' in self.items: + single_metrics['f1-score'] = f1_score + if 'support' in self.items: + single_metrics['support'] = support + return single_metrics + + if self.thr: + suffix = '' if self.thr == 0.5 else f'_thr-{self.thr:.2f}' + for k, v in pack_results(*metric_res).items(): + metrics[k + suffix] = v + else: + for k, v in pack_results(*metric_res).items(): + metrics[k + f'_top{self.topk}'] = v + + result_metrics = dict() + for k, v in metrics.items(): + if self.average is None: + result_metrics[k + '_classwise'] = v.detach().cpu().tolist() + elif self.average == 'macro': + result_metrics[k] = v.item() + else: + result_metrics[k + f'_{self.average}'] = v.item() + return result_metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + pred_indices: bool = False, + target_indices: bool = False, + average: Optional[str] = 'macro', + thr: Optional[float] = None, + topk: Optional[int] = None, + num_classes: Optional[int] = None + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Calculate the precision, recall, f1-score. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, num_classes)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, num_classes)`` or a sequence of index/onehot + format labels. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. If True, ``num_classes`` must be set. + Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. If True, ``num_classes`` must be set. + Defaults to False. + average (str | None): How to calculate the final metrics from + the confusion matrix of every category. It supports three + modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories + and calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + num_classes (Optional, int): The number of classes. If the ``pred`` + is indices instead of onehot, this argument is required. + Defaults to None. + + Returns: + Tuple: The tuple contains precision, recall and f1-score. + And the type of each item is: + + - torch.Tensor: A tensor for each metric. The shape is (1, ) if + ``average`` is not None, and (C, ) if ``average`` is None. + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specicy from {average_options}.' + + def _format_label(label, is_indices): + """format various label to torch.Tensor.""" + if isinstance(label, np.ndarray): + assert label.ndim == 2, 'The shape `pred` and `target` ' \ + 'array must be (N, num_classes).' + label = torch.from_numpy(label) + elif isinstance(label, torch.Tensor): + assert label.ndim == 2, 'The shape `pred` and `target` ' \ + 'tensor must be (N, num_classes).' + elif isinstance(label, Sequence): + if is_indices: + assert num_classes is not None, 'For index-type labels, ' \ + 'please specify `num_classes`.' + label = torch.stack([ + label_to_onehot(indices, num_classes) + for indices in label + ]) + else: + label = torch.stack( + [to_tensor(onehot) for onehot in label]) + else: + raise TypeError( + 'The `pred` and `target` must be type of torch.tensor or ' + f'np.ndarray or sequence but get {type(label)}.') + return label + + pred = _format_label(pred, pred_indices) + target = _format_label(target, target_indices).long() + + assert pred.shape == target.shape, \ + f"The size of pred ({pred.shape}) doesn't match "\ + f'the target ({target.shape}).' + + if num_classes is not None: + assert pred.size(1) == num_classes, \ + f'The shape of `pred` ({pred.shape}) '\ + f"doesn't match the num_classes ({num_classes})." + num_classes = pred.size(1) + + thr = 0.5 if (thr is None and topk is None) else thr + + if thr is not None: + # a label is predicted positive if larger than thr + pos_inds = (pred >= thr).long() + else: + # top-k labels will be predicted positive for any example + _, topk_indices = pred.topk(topk) + pos_inds = torch.zeros_like(pred).scatter_(1, topk_indices, 1) + pos_inds = pos_inds.long() + + return _precision_recall_f1_support(pos_inds, target, average) + + +def _average_precision(pred: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + r"""Calculate the average precision for a single class. + + AP summarizes a precision-recall curve as the weighted mean of maximum + precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + pred (torch.Tensor): The model prediction with shape + ``(N, num_classes)``. + target (torch.Tensor): The target of predictions with shape + ``(N, num_classes)``. + + Returns: + torch.Tensor: average precision result. + """ + assert pred.shape == target.shape, \ + f"The size of pred ({pred.shape}) doesn't match "\ + f'the target ({target.shape}).' + + # a small value for division by zero errors + eps = torch.finfo(torch.float32).eps + + # get rid of -1 target such as difficult sample + # that is not wanted in evaluation results. + valid_index = target > -1 + pred = pred[valid_index] + target = target[valid_index] + + # sort examples + sorted_pred_inds = torch.argsort(pred, dim=0, descending=True) + sorted_target = target[sorted_pred_inds] + + # get indexes when gt_true is positive + pos_inds = sorted_target == 1 + + # Calculate cumulative tp case numbers + tps = torch.cumsum(pos_inds, 0) + total_pos = tps[-1].item() # the last of tensor may change later + + # Calculate cumulative tp&fp(pred_poss) case numbers + pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device) + pred_pos_nums[pred_pos_nums < eps] = eps + + tps[torch.logical_not(pos_inds)] = 0 + precision = tps / pred_pos_nums.float() + ap = torch.sum(precision, 0) / max(total_pos, eps) + return ap + + +@METRICS.register_module() +class AveragePrecision(BaseMetric): + r"""Calculate the average precision with respect of classes. + + AveragePrecision (AP) summarizes a precision-recall curve as the weighted + mean of maximum precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + average (str | None): How to calculate the final metrics from + every category. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. The result of this mode + is also called **mAP**. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + References + ---------- + 1. `Wikipedia entry for the Average precision + `_ + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import AveragePrecision + >>> # --------- The Basic Usage for one-hot pred scores --------- + >>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2], + ... [0.1, 0.2, 0.2, 0.1], + ... [0.7, 0.5, 0.9, 0.3], + ... [0.8, 0.1, 0.1, 0.2]]) + >>> y_true = torch.Tensor([[1, 1, 0, 0], + ... [0, 1, 0, 0], + ... [0, 0, 1, 0], + ... [1, 0, 0, 0]]) + >>> AveragePrecision.calculate(y_pred, y_true) + tensor(70.833) + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_pred_score(i).set_gt_score(j) + ... for i, j in zip(y_pred, y_true) + ... ] + >>> evaluator = Evaluator(metrics=AveragePrecision()) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(5) + {'multi-label/mAP': 70.83333587646484} + >>> # Evaluate on each class + >>> evaluator = Evaluator(metrics=AveragePrecision(average=None)) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(5) + {'multi-label/AP_classwise': [100., 83.33, 100., 0.]} + """ + default_prefix: Optional[str] = 'multi-label' + + def __init__(self, + average: Optional[str] = 'macro', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.average = average + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + + result['pred_score'] = data_sample['pred_score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'gt_score' in data_sample: + result['gt_score'] = data_sample['gt_score'].clone() + else: + result['gt_score'] = label_to_onehot(data_sample['gt_label'], + num_classes) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. + + # concat + target = torch.stack([res['gt_score'] for res in results]) + pred = torch.stack([res['pred_score'] for res in results]) + + ap = self.calculate(pred, target, self.average) + + result_metrics = dict() + + if self.average is None: + result_metrics['AP_classwise'] = ap.detach().cpu().tolist() + else: + result_metrics['mAP'] = ap.item() + + return result_metrics + + @staticmethod + def calculate(pred: Union[torch.Tensor, np.ndarray], + target: Union[torch.Tensor, np.ndarray], + average: Optional[str] = 'macro') -> torch.Tensor: + r"""Calculate the average precision for a single class. + + Args: + pred (torch.Tensor | np.ndarray): The model predictions with + shape ``(N, num_classes)``. + target (torch.Tensor | np.ndarray): The target of predictions + with shape ``(N, num_classes)``. + average (str | None): The average method. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. The result of this mode + is also called mAP. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + + Returns: + torch.Tensor: the average precision of all classes. + """ + average_options = ['macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specicy from {average_options}.' + + pred = to_tensor(pred) + target = to_tensor(target) + assert pred.ndim == 2 and pred.shape == target.shape, \ + 'Both `pred` and `target` should have shape `(N, num_classes)`.' + + num_classes = pred.shape[1] + ap = pred.new_zeros(num_classes) + for k in range(num_classes): + ap[k] = _average_precision(pred[:, k], target[:, k]) + if average == 'macro': + return ap.mean() * 100.0 + else: + return ap * 100 diff --git a/mmpretrain/evaluation/metrics/multi_task.py b/mmpretrain/evaluation/metrics/multi_task.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6af7680192883308df5f24b65ec38c9bb65ce6 --- /dev/null +++ b/mmpretrain/evaluation/metrics/multi_task.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Sequence + +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class MultiTasksMetric(BaseMetric): + """Metrics for MultiTask + Args: + task_metrics(dict): a dictionary in the keys are the names of the tasks + and the values is a list of the metric corresponds to this task + Examples: + >>> import torch + >>> from mmpretrain.evaluation import MultiTasksMetric + # -------------------- The Basic Usage -------------------- + >>>task_metrics = { + 'task0': [dict(type='Accuracy', topk=(1, ))], + 'task1': [dict(type='Accuracy', topk=(1, 3))] + } + >>>pred = [{ + 'pred_task': { + 'task0': torch.tensor([0.7, 0.0, 0.3]), + 'task1': torch.tensor([0.5, 0.2, 0.3]) + }, + 'gt_task': { + 'task0': torch.tensor(0), + 'task1': torch.tensor(2) + } + }, { + 'pred_task': { + 'task0': torch.tensor([0.0, 0.0, 1.0]), + 'task1': torch.tensor([0.0, 0.0, 1.0]) + }, + 'gt_task': { + 'task0': torch.tensor(2), + 'task1': torch.tensor(2) + } + }] + >>>metric = MultiTasksMetric(task_metrics) + >>>metric.process(None, pred) + >>>results = metric.evaluate(2) + results = { + 'task0_accuracy/top1': 100.0, + 'task1_accuracy/top1': 50.0, + 'task1_accuracy/top3': 100.0 + } + """ + + def __init__(self, + task_metrics: Dict, + collect_device: str = 'cpu') -> None: + self.task_metrics = task_metrics + super().__init__(collect_device=collect_device) + + self._metrics = {} + for task_name in self.task_metrics.keys(): + self._metrics[task_name] = [] + for metric in self.task_metrics[task_name]: + self._metrics[task_name].append(METRICS.build(metric)) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for task_name in self.task_metrics.keys(): + filtered_data_samples = [] + for data_sample in data_samples: + eval_mask = data_sample[task_name]['eval_mask'] + if eval_mask: + filtered_data_samples.append(data_sample[task_name]) + for metric in self._metrics[task_name]: + metric.process(data_batch, filtered_data_samples) + + def compute_metrics(self, results: list) -> dict: + raise NotImplementedError( + 'compute metrics should not be used here directly') + + def evaluate(self, size): + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. When batch + size > 1, the dataloader may pad some data samples to make + sure all ranks have the same length of dataset slice. The + ``collect_results`` function will drop the padded data based on + this size. + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are + "{task_name}_{metric_name}" , and the values + are corresponding results. + """ + metrics = {} + for task_name in self._metrics: + for metric in self._metrics[task_name]: + name = metric.__class__.__name__ + if name == 'MultiTasksMetric' or metric.results: + results = metric.evaluate(size) + else: + results = {metric.__class__.__name__: 0} + for key in results: + name = f'{task_name}_{key}' + if name in results: + """Inspired from https://github.com/open- + mmlab/mmengine/ bl ob/ed20a9cba52ceb371f7c825131636b9e2 + 747172e/mmengine/evalua tor/evaluator.py#L84-L87.""" + raise ValueError( + 'There are multiple metric results with the same' + f'metric name {name}. Please make sure all metrics' + 'have different prefixes.') + metrics[name] = results[key] + return metrics diff --git a/mmpretrain/evaluation/metrics/nocaps.py b/mmpretrain/evaluation/metrics/nocaps.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e1d0625b66dfa1abe59bd6f83ea2a6c0b3d446 --- /dev/null +++ b/mmpretrain/evaluation/metrics/nocaps.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import mmengine + +from mmpretrain.registry import METRICS +from mmpretrain.utils import require +from .caption import COCOCaption, save_result + +try: + from pycocoevalcap.eval import COCOEvalCap + from pycocotools.coco import COCO +except ImportError: + COCOEvalCap = None + COCO = None + + +@METRICS.register_module() +class NocapsSave(COCOCaption): + """Nocaps evaluation wrapper. + + Save the generated captions and transform into coco format. + The dumped file can be submitted to the official evluation system. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + + @require('pycocoevalcap') + def __init__(self, + save_dir: str = './', + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super(COCOCaption, self).__init__( + collect_device=collect_device, prefix=prefix) + self.save_dir = save_dir + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + """ + mmengine.mkdir_or_exist(self.save_dir) + save_result( + result=results, + result_dir=self.save_dir, + filename='nocap_pred', + remove_duplicate='image_id', + ) + + return dict() diff --git a/mmpretrain/evaluation/metrics/retrieval.py b/mmpretrain/evaluation/metrics/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..9813486b521c5b73d7be96901ea4f604bbe2a938 --- /dev/null +++ b/mmpretrain/evaluation/metrics/retrieval.py @@ -0,0 +1,445 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +import torch +from mmengine.evaluator import BaseMetric +from mmengine.utils import is_seq_of + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .single_label import to_tensor + + +@METRICS.register_module() +class RetrievalRecall(BaseMetric): + r"""Recall evaluation metric for image retrieval. + + Args: + topk (int | Sequence[int]): If the ground truth label matches one of + the best **k** predictions, the sample will be regard as a positive + prediction. If the parameter is a tuple, all of top-k recall will + be calculated and outputted together. Defaults to 1. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + Use in the code: + + >>> import torch + >>> from mmpretrain.evaluation import RetrievalRecall + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [[0], [1], [2], [3]] + >>> y_true = [[0, 1], [2], [1], [0, 3]] + >>> RetrievalRecall.calculate( + >>> y_pred, y_true, topk=1, pred_indices=True, target_indices=True) + [tensor([50.])] + >>> # Calculate the recall@1 and recall@5 for non-indices input. + >>> y_score = torch.rand((1000, 10)) + >>> import torch.nn.functional as F + >>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10) + >>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5)) + [tensor(9.3000), tensor(48.4000)] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label([0, 1]).set_pred_score( + ... torch.rand(10)) + ... for i in range(1000) + ... ] + >>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5))) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + {'retrieval/Recall@1': 20.700000762939453, + 'retrieval/Recall@5': 78.5999984741211} + + Use in OpenMMLab configs: + + .. code:: python + + val_evaluator = dict(type='RetrievalRecall', topk=(1, 5)) + test_evaluator = val_evaluator + """ + default_prefix: Optional[str] = 'retrieval' + + def __init__(self, + topk: Union[int, Sequence[int]], + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + topk = (topk, ) if isinstance(topk, int) else topk + + for k in topk: + if k <= 0: + raise ValueError('`topk` must be a ingter larger than 0 ' + 'or seq of ingter larger than 0.') + + self.topk = topk + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]): + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_score = data_sample['pred_score'].clone() + gt_label = data_sample['gt_label'] + + if 'gt_score' in data_sample: + target = data_sample.get('gt_score').clone() + else: + num_classes = pred_score.size()[-1] + target = label_to_onehot(gt_label, num_classes) + + # Because the retrieval output logit vector will be much larger + # compared to the normal classification, to save resources, the + # evaluation results are computed each batch here and then reduce + # all results at the end. + result = RetrievalRecall.calculate( + pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk) + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + result_metrics = dict() + for i, k in enumerate(self.topk): + recall_at_k = sum([r[i].item() for r in results]) / len(results) + result_metrics[f'Recall@{k}'] = recall_at_k + + return result_metrics + + @staticmethod + def calculate(pred: Union[np.ndarray, torch.Tensor], + target: Union[np.ndarray, torch.Tensor], + topk: Union[int, Sequence[int]], + pred_indices: (bool) = False, + target_indices: (bool) = False) -> float: + """Calculate the average recall. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + topk (int, Sequence[int]): Predictions with the k-th highest + scores are considered as positive. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. Defaults to False. + + Returns: + List[float]: the average recalls. + """ + topk = (topk, ) if isinstance(topk, int) else topk + for k in topk: + if k <= 0: + raise ValueError('`topk` must be a ingter larger than 0 ' + 'or seq of ingter larger than 0.') + + max_keep = max(topk) + pred = _format_pred(pred, max_keep, pred_indices) + target = _format_target(target, target_indices) + + assert len(pred) == len(target), ( + f'Length of `pred`({len(pred)}) and `target` ({len(target)}) ' + f'must be the same.') + + num_samples = len(pred) + results = [] + for k in topk: + recalls = torch.zeros(num_samples) + for i, (sample_pred, + sample_target) in enumerate(zip(pred, target)): + sample_pred = np.array(to_tensor(sample_pred).cpu()) + sample_target = np.array(to_tensor(sample_target).cpu()) + recalls[i] = int(np.in1d(sample_pred[:k], sample_target).max()) + results.append(recalls.mean() * 100) + return results + + +@METRICS.register_module() +class RetrievalAveragePrecision(BaseMetric): + r"""Calculate the average precision for image retrieval. + + Args: + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. + mode (str, optional): The mode to calculate AP, choose from + 'IR'(information retrieval) and 'integrate'. Defaults to 'IR'. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Note: + If the ``mode`` set to 'IR', use the stanford AP calculation of + information retrieval as in wikipedia page[1]; if set to 'integrate', + the method implemented integrates over the precision-recall curve + by averaging two adjacent precision points, then multiplying by the + recall step like mAP in Detection task. This is the convention for + the Revisited Oxford/Paris datasets[2]. + + References: + [1] `Wikipedia entry for the Average precision `_ + + [2] `The Oxford Buildings Dataset + `_ + + Examples: + Use in code: + + >>> import torch + >>> import numpy as np + >>> from mmcls.evaluation import RetrievalAveragePrecision + >>> # using index format inputs + >>> pred = [ torch.Tensor([idx for idx in range(100)]) ] * 3 + >>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]] + >>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True) + 29.246031746031747 + >>> # using tensor format inputs + >>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2) + >>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2) + >>> RetrievalAveragePrecision.calculate(pred, target, 10) + 62.222222222222214 + + Use in OpenMMLab config files: + + .. code:: python + + val_evaluator = dict(type='RetrievalAveragePrecision', topk=100) + test_evaluator = val_evaluator + """ + + default_prefix: Optional[str] = 'retrieval' + + def __init__(self, + topk: Optional[int] = None, + mode: Optional[str] = 'IR', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + if topk is None or (isinstance(topk, int) and topk <= 0): + raise ValueError('`topk` must be a ingter larger than 0.') + + mode_options = ['IR', 'integrate'] + assert mode in mode_options, \ + f'Invalid `mode` argument, please specify from {mode_options}.' + + self.topk = topk + self.mode = mode + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]): + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_score = data_sample.get('pred_score').clone() + + if 'gt_score' in data_sample: + target = data_sample.get('gt_score').clone() + else: + gt_label = data_sample.get('gt_label') + num_classes = pred_score.size()[-1] + target = label_to_onehot(gt_label, num_classes) + + # Because the retrieval output logit vector will be much larger + # compared to the normal classification, to save resources, the + # evaluation results are computed each batch here and then reduce + # all results at the end. + result = RetrievalAveragePrecision.calculate( + pred_score.unsqueeze(0), + target.unsqueeze(0), + self.topk, + mode=self.mode) + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + result_metrics = dict() + result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item() + + return result_metrics + + @staticmethod + def calculate(pred: Union[np.ndarray, torch.Tensor], + target: Union[np.ndarray, torch.Tensor], + topk: Optional[int] = None, + pred_indices: (bool) = False, + target_indices: (bool) = False, + mode: str = 'IR') -> float: + """Calculate the average precision. + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + topk (int, optional): Predictions with the k-th highest scores + are considered as positive. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. Defaults to False. + mode (Optional[str]): The mode to calculate AP, choose from + 'IR'(information retrieval) and 'integrate'. Defaults to 'IR'. + + Note: + If the ``mode`` set to 'IR', use the stanford AP calculation of + information retrieval as in wikipedia page; if set to 'integrate', + the method implemented integrates over the precision-recall curve + by averaging two adjacent precision points, then multiplying by the + recall step like mAP in Detection task. This is the convention for + the Revisited Oxford/Paris datasets. + + Returns: + float: the average precision of the query image. + + References: + [1] `Wikipedia entry for Average precision(information_retrieval) + `_ + [2] `The Oxford Buildings Dataset 0 else 1 + cur_precision = (i + 1) / (rank + 1) + prediction = (old_precision + cur_precision) / 2 + ap += prediction + ap = ap / len(target) + + return ap * 100 + + +def _format_pred(label, topk=None, is_indices=False): + """format various label to List[indices].""" + if is_indices: + assert isinstance(label, Sequence), \ + '`pred` must be Sequence of indices when' \ + f' `pred_indices` set to True, but get {type(label)}' + for i, sample_pred in enumerate(label): + assert is_seq_of(sample_pred, int) or isinstance( + sample_pred, (np.ndarray, torch.Tensor)), \ + '`pred` should be Sequence of indices when `pred_indices`' \ + f'set to True. but pred[{i}] is {sample_pred}' + if topk: + label[i] = sample_pred[:min(topk, len(sample_pred))] + return label + if isinstance(label, np.ndarray): + label = torch.from_numpy(label) + elif not isinstance(label, torch.Tensor): + raise TypeError(f'The pred must be type of torch.tensor, ' + f'np.ndarray or Sequence but get {type(label)}.') + topk = topk if topk else label.size()[-1] + _, indices = label.topk(topk) + return indices + + +def _format_target(label, is_indices=False): + """format various label to List[indices].""" + if is_indices: + assert isinstance(label, Sequence), \ + '`target` must be Sequence of indices when' \ + f' `target_indices` set to True, but get {type(label)}' + for i, sample_gt in enumerate(label): + assert is_seq_of(sample_gt, int) or isinstance( + sample_gt, (np.ndarray, torch.Tensor)), \ + '`target` should be Sequence of indices when ' \ + f'`target_indices` set to True. but target[{i}] is {sample_gt}' + return label + + if isinstance(label, np.ndarray): + label = torch.from_numpy(label) + elif isinstance(label, Sequence) and not mmengine.is_str(label): + label = torch.tensor(label) + elif not isinstance(label, torch.Tensor): + raise TypeError(f'The pred must be type of torch.tensor, ' + f'np.ndarray or Sequence but get {type(label)}.') + + indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label] + return indices diff --git a/mmpretrain/evaluation/metrics/scienceqa.py b/mmpretrain/evaluation/metrics/scienceqa.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf01c78cc88e5ce5e232fe837a0d77293386112 --- /dev/null +++ b/mmpretrain/evaluation/metrics/scienceqa.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import List, Optional + +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +def get_pred_idx(prediction: str, choices: List[str], + options: List[str]) -> int: # noqa + """Get the index (e.g. 2) from the prediction (e.g. 'C') + + Args: + prediction (str): The prediction from the model, + from ['A', 'B', 'C', 'D', 'E'] + choices (List(str)): The choices for the question, + from ['A', 'B', 'C', 'D', 'E'] + options (List(str)): The options for the question, + from ['A', 'B', 'C', 'D', 'E'] + + Returns: + int: The index of the prediction, from [0, 1, 2, 3, 4] + """ + if prediction in options[:len(choices)]: + return options.index(prediction) + else: + return random.choice(range(len(choices))) + + +@METRICS.register_module() +class ScienceQAMetric(BaseMetric): + """Evaluation Metric for ScienceQA. + + Args: + options (List(str)): Options for each question. Defaults to + ["A", "B", "C", "D", "E"]. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + + def __init__(self, + options: List[str] = ['A', 'B', 'C', 'D', 'E'], + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.options = options + + def process(self, data_batch, data_samples) -> None: + """Process one batch of data samples. + + data_samples should contain the following keys: + 1. pred_answer (str): The prediction from the model, + from ['A', 'B', 'C', 'D', 'E'] + 2. choices (List(str)): The choices for the question, + from ['A', 'B', 'C', 'D', 'E'] + 3. grade (int): The grade for the question, from grade1 to grade12 + 4. subject (str): The subject for the question, from + ['natural science', 'social science', 'language science'] + 5. answer (str): The answer for the question, from + ['A', 'B', 'C', 'D', 'E'] + 6. hint (str): The hint for the question + 7. has_image (bool): Whether or not the question has image + + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + choices = data_sample.get('choices') + result['prediction'] = get_pred_idx( + data_sample.get('pred_answer'), choices, self.options) + result['grade'] = data_sample.get('grade') + result['subject'] = data_sample.get('subject') + result['answer'] = data_sample.get('gt_answer') + hint = data_sample.get('hint') + has_image = data_sample.get('has_image', False) + result['no_context'] = True if not has_image and len( + hint) == 0 else False # noqa + result['has_text'] = True if len(hint) > 0 else False + result['has_image'] = has_image + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List) -> dict: + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. + metrics = dict() + + all_acc = [] + acc_natural = [] + acc_social = [] + acc_language = [] + acc_has_text = [] + acc_has_image = [] + acc_no_context = [] + acc_grade_1_6 = [] + acc_grade_7_12 = [] + + for result in results: + correct = result['prediction'] == result['answer'] + all_acc.append(correct) + # different subjects + if result['subject'] == 'natural science': + acc_natural.append(correct) + elif result['subject'] == 'social science': + acc_social.append(correct) + elif result['subject'] == 'language science': + acc_language.append(correct) + + # different context + if result['has_text']: + acc_has_text.append(correct) + elif result['has_image']: + acc_has_image.append(correct) + elif result['no_context']: + acc_no_context.append(correct) + + # different grade + if result['grade'] in [ + 'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6' + ]: + acc_grade_1_6.append(correct) + elif result['grade'] in [ + 'grade7', 'grade8', 'grade9', 'grade10', 'grade11', + 'grade12' + ]: + acc_grade_7_12.append(correct) + + metrics['all_acc'] = sum(all_acc) / len(all_acc) + if len(acc_natural) > 0: + metrics['acc_natural'] = sum(acc_natural) / len(acc_natural) + if len(acc_social) > 0: + metrics['acc_social'] = sum(acc_social) / len(acc_social) + if len(acc_language) > 0: + metrics['acc_language'] = sum(acc_language) / len(acc_language) + if len(acc_has_text) > 0: + metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text) + if len(acc_has_image) > 0: + metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image) + if len(acc_no_context) > 0: + metrics['acc_no_context'] = sum(acc_no_context) / len( + acc_no_context) + if len(acc_grade_1_6) > 0: + metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6) + if len(acc_grade_7_12) > 0: + metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len( + acc_grade_7_12) + + return metrics diff --git a/mmpretrain/evaluation/metrics/shape_bias_label.py b/mmpretrain/evaluation/metrics/shape_bias_label.py new file mode 100644 index 0000000000000000000000000000000000000000..27c80a36073a9e6edd5e6583e213ed93374b165e --- /dev/null +++ b/mmpretrain/evaluation/metrics/shape_bias_label.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import csv +import os +import os.path as osp +from typing import List, Sequence + +import numpy as np +import torch +from mmengine.dist.utils import get_rank +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class ShapeBiasMetric(BaseMetric): + """Evaluate the model on ``cue_conflict`` dataset. + + This module will evaluate the model on an OOD dataset, cue_conflict, in + order to measure the shape bias of the model. In addition to compuate the + Top-1 accuracy, this module also generate a csv file to record the + detailed prediction results, such that this csv file can be used to + generate the shape bias curve. + + Args: + csv_dir (str): The directory to save the csv file. + model_name (str): The name of the csv file. Please note that the + model name should be an unique identifier. + dataset_name (str): The name of the dataset. Default: 'cue_conflict'. + """ + + # mapping several classes from ImageNet-1K to the same category + airplane_indices = [404] + bear_indices = [294, 295, 296, 297] + bicycle_indices = [444, 671] + bird_indices = [ + 8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83, + 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129, + 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, + 145 + ] + boat_indices = [472, 554, 625, 814, 914] + bottle_indices = [440, 720, 737, 898, 899, 901, 907] + car_indices = [436, 511, 817] + cat_indices = [281, 282, 283, 284, 285, 286] + chair_indices = [423, 559, 765, 857] + clock_indices = [409, 530, 892] + dog_indices = [ + 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, + 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194, + 195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, + 224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, + 239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254, + 255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268 + ] + elephant_indices = [385, 386] + keyboard_indices = [508, 878] + knife_indices = [499] + oven_indices = [766] + truck_indices = [555, 569, 656, 675, 717, 734, 864, 867] + + def __init__(self, + csv_dir: str, + model_name: str, + dataset_name: str = 'cue_conflict', + **kwargs) -> None: + super().__init__(**kwargs) + + self.categories = sorted([ + 'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock', + 'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car', + 'bird', 'dog' + ]) + self.csv_dir = csv_dir + self.model_name = model_name + self.dataset_name = dataset_name + if get_rank() == 0: + self.csv_path = self.create_csv() + + def process(self, data_batch, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + result['pred_label'] = data_sample['pred_label'].cpu() + result['gt_label'] = data_sample['gt_label'].cpu() + result['gt_category'] = data_sample['img_path'].split('/')[-2] + result['img_name'] = data_sample['img_path'].split('/')[-1] + + aggregated_category_probabilities = [] + # get the prediction for each category of current instance + for category in self.categories: + category_indices = getattr(self, f'{category}_indices') + category_probabilities = torch.gather( + result['pred_score'], 0, + torch.tensor(category_indices)).mean() + aggregated_category_probabilities.append( + category_probabilities) + # sort the probabilities in descending order + pred_indices = torch.stack(aggregated_category_probabilities + ).argsort(descending=True).numpy() + result['pred_category'] = np.take(self.categories, pred_indices) + + # Save the result to `self.results`. + self.results.append(result) + + def create_csv(self) -> str: + """Create a csv file to store the results.""" + session_name = 'session-1' + csv_path = osp.join( + self.csv_dir, self.dataset_name + '_' + self.model_name + '_' + + session_name + '.csv') + if osp.exists(csv_path): + os.remove(csv_path) + directory = osp.dirname(csv_path) + if not osp.exists(directory): + os.makedirs(directory, exist_ok=True) + with open(csv_path, 'w') as f: + writer = csv.writer(f) + writer.writerow([ + 'subj', 'session', 'trial', 'rt', 'object_response', + 'category', 'condition', 'imagename' + ]) + return csv_path + + def dump_results_to_csv(self, results: List[dict]) -> None: + """Dump the results to a csv file. + + Args: + results (List[dict]): A list of results. + """ + for i, result in enumerate(results): + img_name = result['img_name'] + category = result['gt_category'] + condition = 'NaN' + with open(self.csv_path, 'a') as f: + writer = csv.writer(f) + writer.writerow([ + self.model_name, 1, i + 1, 'NaN', + result['pred_category'][0], category, condition, img_name + ]) + + def compute_metrics(self, results: List[dict]) -> dict: + """Compute the metrics from the results. + + Args: + results (List[dict]): A list of results. + + Returns: + dict: A dict of metrics. + """ + if get_rank() == 0: + self.dump_results_to_csv(results) + metrics = dict() + metrics['accuracy/top1'] = np.mean([ + result['pred_category'][0] == result['gt_category'] + for result in results + ]) + + return metrics diff --git a/mmpretrain/evaluation/metrics/single_label.py b/mmpretrain/evaluation/metrics/single_label.py new file mode 100644 index 0000000000000000000000000000000000000000..f9329b9567e698a4e3ebdb7d77f0f8404b81ad4c --- /dev/null +++ b/mmpretrain/evaluation/metrics/single_label.py @@ -0,0 +1,776 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +def to_tensor(value): + """Convert value to torch.Tensor.""" + if isinstance(value, np.ndarray): + value = torch.from_numpy(value) + elif isinstance(value, Sequence) and not mmengine.is_str(value): + value = torch.tensor(value) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'{type(value)} is not an available argument.') + return value + + +def _precision_recall_f1_support(pred_positive, gt_positive, average): + """calculate base classification task metrics, such as precision, recall, + f1_score, support.""" + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.' + + # ignore -1 target such as difficult sample that is not wanted + # in evaluation results. + # only for calculate multi-label without affecting single-label behavior + ignored_index = gt_positive == -1 + pred_positive[ignored_index] = 0 + gt_positive[ignored_index] = 0 + + class_correct = (pred_positive & gt_positive) + if average == 'micro': + tp_sum = class_correct.sum() + pred_sum = pred_positive.sum() + gt_sum = gt_positive.sum() + else: + tp_sum = class_correct.sum(0) + pred_sum = pred_positive.sum(0) + gt_sum = gt_positive.sum(0) + + precision = tp_sum / torch.clamp(pred_sum, min=1).float() * 100 + recall = tp_sum / torch.clamp(gt_sum, min=1).float() * 100 + f1_score = 2 * precision * recall / torch.clamp( + precision + recall, min=torch.finfo(torch.float32).eps) + if average in ['macro', 'micro']: + precision = precision.mean(0) + recall = recall.mean(0) + f1_score = f1_score.mean(0) + support = gt_sum.sum(0) + else: + support = gt_sum + return precision, recall, f1_score, support + + +@METRICS.register_module() +class Accuracy(BaseMetric): + r"""Accuracy evaluation metric. + + For either binary classification or multi-class classification, the + accuracy is the fraction of correct predictions in all predictions: + + .. math:: + + \text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}} + + Args: + topk (int | Sequence[int]): If the ground truth label matches one of + the best **k** predictions, the sample will be regard as a positive + prediction. If the parameter is a tuple, all of top-k accuracy will + be calculated and outputted together. Defaults to 1. + thrs (Sequence[float | None] | float | None): If a float, predictions + with score lower than the threshold will be regard as the negative + prediction. If None, not apply threshold. If the parameter is a + tuple, accuracy based on all thresholds will be calculated and + outputted together. Defaults to 0. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import Accuracy + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [0, 2, 1, 3] + >>> y_true = [0, 1, 2, 3] + >>> Accuracy.calculate(y_pred, y_true) + tensor([50.]) + >>> # Calculate the top1 and top5 accuracy. + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.zeros((1000, )) + >>> Accuracy.calculate(y_score, y_true, topk=(1, 5)) + [[tensor([9.9000])], [tensor([51.5000])]] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label(0).set_pred_score(torch.rand(10)) + ... for i in range(1000) + ... ] + >>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5))) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + { + 'accuracy/top1': 9.300000190734863, + 'accuracy/top5': 51.20000076293945 + } + """ + default_prefix: Optional[str] = 'accuracy' + + def __init__(self, + topk: Union[int, Sequence[int]] = (1, ), + thrs: Union[float, Sequence[Union[float, None]], None] = 0., + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + if isinstance(topk, int): + self.topk = (topk, ) + else: + self.topk = tuple(topk) + + if isinstance(thrs, float) or thrs is None: + self.thrs = (thrs, ) + else: + self.thrs = tuple(thrs) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + result['pred_label'] = data_sample['pred_label'].cpu() + result['gt_label'] = data_sample['gt_label'].cpu() + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. + metrics = {} + + # concat + target = torch.cat([res['gt_label'] for res in results]) + if 'pred_score' in results[0]: + pred = torch.stack([res['pred_score'] for res in results]) + + try: + acc = self.calculate(pred, target, self.topk, self.thrs) + except ValueError as e: + # If the topk is invalid. + raise ValueError( + str(e) + ' Please check the `val_evaluator` and ' + '`test_evaluator` fields in your config file.') + + multi_thrs = len(self.thrs) > 1 + for i, k in enumerate(self.topk): + for j, thr in enumerate(self.thrs): + name = f'top{k}' + if multi_thrs: + name += '_no-thr' if thr is None else f'_thr-{thr:.2f}' + metrics[name] = acc[i][j].item() + else: + # If only label in the `pred_label`. + pred = torch.cat([res['pred_label'] for res in results]) + acc = self.calculate(pred, target, self.topk, self.thrs) + metrics['top1'] = acc.item() + + return metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + topk: Sequence[int] = (1, ), + thrs: Sequence[Union[float, None]] = (0., ), + ) -> Union[torch.Tensor, List[List[torch.Tensor]]]: + """Calculate the accuracy. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. It can be labels (N, ), or scores of every + class (N, C). + target (torch.Tensor | np.ndarray | Sequence): The target of + each prediction with shape (N, ). + thrs (Sequence[float | None]): Predictions with scores under + the thresholds are considered negative. It's only used + when ``pred`` is scores. None means no thresholds. + Defaults to (0., ). + thrs (Sequence[float]): Predictions with scores under + the thresholds are considered negative. It's only used + when ``pred`` is scores. Defaults to (0., ). + + Returns: + torch.Tensor | List[List[torch.Tensor]]: Accuracy. + + - torch.Tensor: If the ``pred`` is a sequence of label instead of + score (number of dimensions is 1). Only return a top-1 accuracy + tensor, and ignore the argument ``topk` and ``thrs``. + - List[List[torch.Tensor]]: If the ``pred`` is a sequence of score + (number of dimensions is 2). Return the accuracy on each ``topk`` + and ``thrs``. And the first dim is ``topk``, the second dim is + ``thrs``. + """ + + pred = to_tensor(pred) + target = to_tensor(target).to(torch.int64) + num = pred.size(0) + assert pred.size(0) == target.size(0), \ + f"The size of pred ({pred.size(0)}) doesn't match "\ + f'the target ({target.size(0)}).' + + if pred.ndim == 1: + # For pred label, ignore topk and acc + pred_label = pred.int() + correct = pred.eq(target).float().sum(0, keepdim=True) + acc = correct.mul_(100. / num) + return acc + else: + # For pred score, calculate on all topk and thresholds. + pred = pred.float() + maxk = max(topk) + + if maxk > pred.size(1): + raise ValueError( + f'Top-{maxk} accuracy is unavailable since the number of ' + f'categories is {pred.size(1)}.') + + pred_score, pred_label = pred.topk(maxk, dim=1) + pred_label = pred_label.t() + correct = pred_label.eq(target.view(1, -1).expand_as(pred_label)) + results = [] + for k in topk: + results.append([]) + for thr in thrs: + # Only prediction values larger than thr are counted + # as correct + _correct = correct + if thr is not None: + _correct = _correct & (pred_score.t() > thr) + correct_k = _correct[:k].reshape(-1).float().sum( + 0, keepdim=True) + acc = correct_k.mul_(100. / num) + results[-1].append(acc) + return results + + +@METRICS.register_module() +class SingleLabelMetric(BaseMetric): + r"""A collection of precision, recall, f1-score and support for + single-label tasks. + + The collection of metrics is for single-label multi-class classification. + And all these metrics are based on the confusion matrix of every category: + + .. image:: ../../_static/image/confusion-matrix.png + :width: 60% + :align: center + + All metrics can be formulated use variables above: + + **Precision** is the fraction of correct predictions in all predictions: + + .. math:: + \text{Precision} = \frac{TP}{TP+FP} + + **Recall** is the fraction of correct predictions in all targets: + + .. math:: + \text{Recall} = \frac{TP}{TP+FN} + + **F1-score** is the harmonic mean of the precision and recall: + + .. math:: + \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}} + + **Support** is the number of samples: + + .. math:: + \text{Support} = TP + TN + FN + FP + + Args: + thrs (Sequence[float | None] | float | None): If a float, predictions + with score lower than the threshold will be regard as the negative + prediction. If None, only the top-1 prediction will be regard as + the positive prediction. If the parameter is a tuple, accuracy + based on all thresholds will be calculated and outputted together. + Defaults to 0. + items (Sequence[str]): The detailed metric items to evaluate, select + from "precision", "recall", "f1-score" and "support". + Defaults to ``('precision', 'recall', 'f1-score')``. + average (str | None): How to calculate the final metrics from the + confusion matrix of every category. It supports three modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories and + calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + num_classes (int, optional): The number of classes. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import SingleLabelMetric + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [0, 1, 1, 3] + >>> y_true = [0, 2, 1, 3] + >>> # Output precision, recall, f1-score and support. + >>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4) + (tensor(62.5000), tensor(75.), tensor(66.6667), tensor(4)) + >>> # Calculate with different thresholds. + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.zeros((1000, )) + >>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9)) + [(tensor(10.), tensor(0.9500), tensor(1.7352), tensor(1000)), + (tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5)) + ... for i in range(1000) + ... ] + >>> evaluator = Evaluator(metrics=SingleLabelMetric()) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + {'single-label/precision': 19.650691986083984, + 'single-label/recall': 19.600000381469727, + 'single-label/f1-score': 19.619548797607422} + >>> # Evaluate on each class + >>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None)) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + { + 'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1], + 'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0], + 'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0] + } + """ # noqa: E501 + default_prefix: Optional[str] = 'single-label' + + def __init__(self, + thrs: Union[float, Sequence[Union[float, None]], None] = 0., + items: Sequence[str] = ('precision', 'recall', 'f1-score'), + average: Optional[str] = 'macro', + num_classes: Optional[int] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + if isinstance(thrs, float) or thrs is None: + self.thrs = (thrs, ) + else: + self.thrs = tuple(thrs) + + for item in items: + assert item in ['precision', 'recall', 'f1-score', 'support'], \ + f'The metric {item} is not supported by `SingleLabelMetric`,' \ + ' please specify from "precision", "recall", "f1-score" and ' \ + '"support".' + self.items = tuple(items) + self.average = average + self.num_classes = num_classes + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + num_classes = self.num_classes or data_sample.get( + 'num_classes') + assert num_classes is not None, \ + 'The `num_classes` must be specified if no `pred_score`.' + result['pred_label'] = data_sample['pred_label'].cpu() + result['num_classes'] = num_classes + result['gt_label'] = data_sample['gt_label'].cpu() + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. + metrics = {} + + def pack_results(precision, recall, f1_score, support): + single_metrics = {} + if 'precision' in self.items: + single_metrics['precision'] = precision + if 'recall' in self.items: + single_metrics['recall'] = recall + if 'f1-score' in self.items: + single_metrics['f1-score'] = f1_score + if 'support' in self.items: + single_metrics['support'] = support + return single_metrics + + # concat + target = torch.cat([res['gt_label'] for res in results]) + if 'pred_score' in results[0]: + pred = torch.stack([res['pred_score'] for res in results]) + metrics_list = self.calculate( + pred, target, thrs=self.thrs, average=self.average) + + multi_thrs = len(self.thrs) > 1 + for i, thr in enumerate(self.thrs): + if multi_thrs: + suffix = '_no-thr' if thr is None else f'_thr-{thr:.2f}' + else: + suffix = '' + + for k, v in pack_results(*metrics_list[i]).items(): + metrics[k + suffix] = v + else: + # If only label in the `pred_label`. + pred = torch.cat([res['pred_label'] for res in results]) + res = self.calculate( + pred, + target, + average=self.average, + num_classes=results[0]['num_classes']) + metrics = pack_results(*res) + + result_metrics = dict() + for k, v in metrics.items(): + + if self.average is None: + result_metrics[k + '_classwise'] = v.cpu().detach().tolist() + elif self.average == 'micro': + result_metrics[k + f'_{self.average}'] = v.item() + else: + result_metrics[k] = v.item() + + return result_metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + thrs: Sequence[Union[float, None]] = (0., ), + average: Optional[str] = 'macro', + num_classes: Optional[int] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Calculate the precision, recall, f1-score and support. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. It can be labels (N, ), or scores of every + class (N, C). + target (torch.Tensor | np.ndarray | Sequence): The target of + each prediction with shape (N, ). + thrs (Sequence[float | None]): Predictions with scores under + the thresholds are considered negative. It's only used + when ``pred`` is scores. None means no thresholds. + Defaults to (0., ). + average (str | None): How to calculate the final metrics from + the confusion matrix of every category. It supports three + modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories + and calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + num_classes (Optional, int): The number of classes. If the ``pred`` + is label instead of scores, this argument is required. + Defaults to None. + + Returns: + Tuple: The tuple contains precision, recall and f1-score. + And the type of each item is: + + - torch.Tensor: If the ``pred`` is a sequence of label instead of + score (number of dimensions is 1). Only returns a tensor for + each metric. The shape is (1, ) if ``classwise`` is False, and + (C, ) if ``classwise`` is True. + - List[torch.Tensor]: If the ``pred`` is a sequence of score + (number of dimensions is 2). Return the metrics on each ``thrs``. + The shape of tensor is (1, ) if ``classwise`` is False, and (C, ) + if ``classwise`` is True. + """ + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.' + + pred = to_tensor(pred) + target = to_tensor(target).to(torch.int64) + assert pred.size(0) == target.size(0), \ + f"The size of pred ({pred.size(0)}) doesn't match "\ + f'the target ({target.size(0)}).' + + if pred.ndim == 1: + assert num_classes is not None, \ + 'Please specify the `num_classes` if the `pred` is labels ' \ + 'intead of scores.' + gt_positive = F.one_hot(target.flatten(), num_classes) + pred_positive = F.one_hot(pred.to(torch.int64), num_classes) + return _precision_recall_f1_support(pred_positive, gt_positive, + average) + else: + # For pred score, calculate on all thresholds. + num_classes = pred.size(1) + pred_score, pred_label = torch.topk(pred, k=1) + pred_score = pred_score.flatten() + pred_label = pred_label.flatten() + + gt_positive = F.one_hot(target.flatten(), num_classes) + + results = [] + for thr in thrs: + pred_positive = F.one_hot(pred_label, num_classes) + if thr is not None: + pred_positive[pred_score <= thr] = 0 + results.append( + _precision_recall_f1_support(pred_positive, gt_positive, + average)) + + return results + + +@METRICS.register_module() +class ConfusionMatrix(BaseMetric): + r"""A metric to calculate confusion matrix for single-label tasks. + + Args: + num_classes (int, optional): The number of classes. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + + 1. The basic usage. + + >>> import torch + >>> from mmpretrain.evaluation import ConfusionMatrix + >>> y_pred = [0, 1, 1, 3] + >>> y_true = [0, 2, 1, 3] + >>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4) + tensor([[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1]]) + >>> # plot the confusion matrix + >>> import matplotlib.pyplot as plt + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.randint(10, (1000, )) + >>> matrix = ConfusionMatrix.calculate(y_score, y_true) + >>> ConfusionMatrix().plot(matrix) + >>> plt.show() + + 2. In the config file + + .. code:: python + + val_evaluator = dict(type='ConfusionMatrix') + test_evaluator = dict(type='ConfusionMatrix') + """ # noqa: E501 + default_prefix = 'confusion_matrix' + + def __init__(self, + num_classes: Optional[int] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device, prefix) + + self.num_classes = num_classes + + def process(self, data_batch, data_samples: Sequence[dict]) -> None: + for data_sample in data_samples: + if 'pred_score' in data_sample: + pred_score = data_sample['pred_score'] + pred_label = pred_score.argmax(dim=0, keepdim=True) + self.num_classes = pred_score.size(0) + else: + pred_label = data_sample['pred_label'] + + self.results.append({ + 'pred_label': pred_label, + 'gt_label': data_sample['gt_label'], + }) + + def compute_metrics(self, results: list) -> dict: + pred_labels = [] + gt_labels = [] + for result in results: + pred_labels.append(result['pred_label']) + gt_labels.append(result['gt_label']) + confusion_matrix = ConfusionMatrix.calculate( + torch.cat(pred_labels), + torch.cat(gt_labels), + num_classes=self.num_classes) + return {'result': confusion_matrix} + + @staticmethod + def calculate(pred, target, num_classes=None) -> dict: + """Calculate the confusion matrix for single-label task. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. It can be labels (N, ), or scores of every + class (N, C). + target (torch.Tensor | np.ndarray | Sequence): The target of + each prediction with shape (N, ). + num_classes (Optional, int): The number of classes. If the ``pred`` + is label instead of scores, this argument is required. + Defaults to None. + + Returns: + torch.Tensor: The confusion matrix. + """ + pred = to_tensor(pred) + target_label = to_tensor(target).int() + + assert pred.size(0) == target_label.size(0), \ + f"The size of pred ({pred.size(0)}) doesn't match "\ + f'the target ({target_label.size(0)}).' + assert target_label.ndim == 1 + + if pred.ndim == 1: + assert num_classes is not None, \ + 'Please specify the `num_classes` if the `pred` is labels ' \ + 'intead of scores.' + pred_label = pred + else: + num_classes = num_classes or pred.size(1) + pred_label = torch.argmax(pred, dim=1).flatten() + + with torch.no_grad(): + indices = num_classes * target_label + pred_label + matrix = torch.bincount(indices, minlength=num_classes**2) + matrix = matrix.reshape(num_classes, num_classes) + + return matrix + + @staticmethod + def plot(confusion_matrix: torch.Tensor, + include_values: bool = False, + cmap: str = 'viridis', + classes: Optional[List[str]] = None, + colorbar: bool = True, + show: bool = True): + """Draw a confusion matrix by matplotlib. + + Modified from `Scikit-Learn + `_ + + Args: + confusion_matrix (torch.Tensor): The confusion matrix to draw. + include_values (bool): Whether to draw the values in the figure. + Defaults to False. + cmap (str): The color map to use. Defaults to use "viridis". + classes (list[str], optional): The names of categories. + Defaults to None, which means to use index number. + colorbar (bool): Whether to show the colorbar. Defaults to True. + show (bool): Whether to show the figure immediately. + Defaults to True. + """ # noqa: E501 + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(10, 10)) + + num_classes = confusion_matrix.size(0) + + im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap) + text_ = None + cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0) + + if include_values: + text_ = np.empty_like(confusion_matrix, dtype=object) + + # print text with appropriate color depending on background + thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0 + + for i, j in product(range(num_classes), range(num_classes)): + color = cmap_max if confusion_matrix[i, + j] < thresh else cmap_min + + text_cm = format(confusion_matrix[i, j], '.2g') + text_d = format(confusion_matrix[i, j], 'd') + if len(text_d) < len(text_cm): + text_cm = text_d + + text_[i, j] = ax.text( + j, i, text_cm, ha='center', va='center', color=color) + + display_labels = classes or np.arange(num_classes) + + if colorbar: + fig.colorbar(im_, ax=ax) + ax.set( + xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=display_labels, + yticklabels=display_labels, + ylabel='True label', + xlabel='Predicted label', + ) + ax.invert_yaxis() + ax.xaxis.tick_top() + + ax.set_ylim((num_classes - 0.5, -0.5)) + # Automatically rotate the x labels. + fig.autofmt_xdate(ha='center') + + if show: + plt.show() + return fig diff --git a/mmpretrain/evaluation/metrics/visual_grounding_eval.py b/mmpretrain/evaluation/metrics/visual_grounding_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ad16e5adf4660496b3a984087294ed9c0fee6537 --- /dev/null +++ b/mmpretrain/evaluation/metrics/visual_grounding_eval.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torchvision.ops.boxes as boxes +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +def aligned_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor): + area1 = boxes.box_area(boxes1) + area2 = boxes.box_area(boxes2) + + lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (B, 2) + rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (B, 2) + + wh = boxes._upcast(rb - lt).clamp(min=0) # (B, 2) + inter = wh[:, 0] * wh[:, 1] # (B, ) + + union = area1 + area2 - inter + iou = inter / union + return iou + + +@METRICS.register_module() +class VisualGroundingMetric(BaseMetric): + """Visual Grounding evaluator. + + Calculate the box mIOU and box grounding accuracy for visual grounding + model. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'visual-grounding' + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for preds in data_samples: + + pred_box = preds['pred_bboxes'].squeeze() + box_gt = torch.Tensor(preds['gt_bboxes']).squeeze() + + result = { + 'box': pred_box.to('cpu').squeeze(), + 'box_target': box_gt.squeeze(), + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + pred_boxes = torch.stack([each['box'] for each in results]) + gt_boxes = torch.stack([each['box_target'] for each in results]) + iou = aligned_box_iou(pred_boxes, gt_boxes) + accu_num = torch.sum(iou >= 0.5) + + miou = torch.mean(iou) + acc = accu_num / len(gt_boxes) + coco_val = {'miou': miou, 'acc': acc} + return coco_val diff --git a/mmpretrain/evaluation/metrics/voc_multi_label.py b/mmpretrain/evaluation/metrics/voc_multi_label.py new file mode 100644 index 0000000000000000000000000000000000000000..1034852722796271c7ade9d75c3442cce8f1d0d1 --- /dev/null +++ b/mmpretrain/evaluation/metrics/voc_multi_label.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .multi_label import AveragePrecision, MultiLabelMetric + + +class VOCMetricMixin: + """A mixin class for VOC dataset metrics, VOC annotations have extra + `difficult` attribute for each object, therefore, extra option is needed + for calculating VOC metrics. + + Args: + difficult_as_postive (Optional[bool]): Whether to map the difficult + labels as positive in one-hot ground truth for evaluation. If it + set to True, map difficult gt labels to positive ones(1), If it + set to False, map difficult gt labels to negative ones(0). + Defaults to None, the difficult labels will be set to '-1'. + """ + + def __init__(self, + *arg, + difficult_as_positive: Optional[bool] = None, + **kwarg): + self.difficult_as_positive = difficult_as_positive + super().__init__(*arg, **kwarg) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + gt_label = data_sample['gt_label'] + gt_label_difficult = data_sample['gt_label_difficult'] + + result['pred_score'] = data_sample['pred_score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'gt_score' in data_sample: + result['gt_score'] = data_sample['gt_score'].clone() + else: + result['gt_score'] = label_to_onehot(gt_label, num_classes) + + # VOC annotation labels all the objects in a single image + # therefore, some categories are appeared both in + # difficult objects and non-difficult objects. + # Here we reckon those labels which are only exists in difficult + # objects as difficult labels. + difficult_label = set(gt_label_difficult) - ( + set(gt_label_difficult) & set(gt_label.tolist())) + + # set difficult label for better eval + if self.difficult_as_positive is None: + result['gt_score'][[*difficult_label]] = -1 + elif self.difficult_as_positive: + result['gt_score'][[*difficult_label]] = 1 + + # Save the result to `self.results`. + self.results.append(result) + + +@METRICS.register_module() +class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric): + """A collection of metrics for multi-label multi-class classification task + based on confusion matrix for VOC dataset. + + It includes precision, recall, f1-score and support. + + Args: + difficult_as_postive (Optional[bool]): Whether to map the difficult + labels as positive in one-hot ground truth for evaluation. If it + set to True, map difficult gt labels to positive ones(1), If it + set to False, map difficult gt labels to negative ones(0). + Defaults to None, the difficult labels will be set to '-1'. + **kwarg: Refers to `MultiLabelMetric` for detailed docstrings. + """ + + +@METRICS.register_module() +class VOCAveragePrecision(VOCMetricMixin, AveragePrecision): + """Calculate the average precision with respect of classes for VOC dataset. + + Args: + difficult_as_postive (Optional[bool]): Whether to map the difficult + labels as positive in one-hot ground truth for evaluation. If it + set to True, map difficult gt labels to positive ones(1), If it + set to False, map difficult gt labels to negative ones(0). + Defaults to None, the difficult labels will be set to '-1'. + **kwarg: Refers to `AveragePrecision` for detailed docstrings. + """ diff --git a/mmpretrain/evaluation/metrics/vqa.py b/mmpretrain/evaluation/metrics/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..fd77ba9bc23e013c41ac095810740bdb71d33fb3 --- /dev/null +++ b/mmpretrain/evaluation/metrics/vqa.py @@ -0,0 +1,315 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Partly adopted from https://github.com/GT-Vision-Lab/VQA +# Copyright (c) 2014, Aishwarya Agrawal +from typing import List, Optional + +import mmengine +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger + +from mmpretrain.registry import METRICS + + +def _process_punctuation(inText): + import re + outText = inText + punct = [ + ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', + '>', '<', '@', '`', ',', '?', '!' + ] + commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 + periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 + for p in punct: + if (p + ' ' in inText or ' ' + p in inText) or (re.search( + commaStrip, inText) is not None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = periodStrip.sub('', outText, re.UNICODE) + return outText + + +def _process_digit_article(inText): + outText = [] + tempText = inText.lower().split() + articles = ['a', 'an', 'the'] + manualMap = { + 'none': '0', + 'zero': '0', + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10', + } + contractions = { + 'aint': "ain't", + 'arent': "aren't", + 'cant': "can't", + 'couldve': "could've", + 'couldnt': "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + 'didnt': "didn't", + 'doesnt': "doesn't", + 'dont': "don't", + 'hadnt': "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + 'hasnt': "hasn't", + 'havent': "haven't", + 'hed': "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + 'hes': "he's", + 'howd': "how'd", + 'howll': "how'll", + 'hows': "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + 'Im': "I'm", + 'Ive': "I've", + 'isnt': "isn't", + 'itd': "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + 'itll': "it'll", + "let's": "let's", + 'maam': "ma'am", + 'mightnt': "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + 'mightve': "might've", + 'mustnt': "mustn't", + 'mustve': "must've", + 'neednt': "needn't", + 'notve': "not've", + 'oclock': "o'clock", + 'oughtnt': "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + 'shant': "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + 'shouldve': "should've", + 'shouldnt': "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": 'somebodyd', + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + 'somebodyll': "somebody'll", + 'somebodys': "somebody's", + 'someoned': "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + 'someonell': "someone'll", + 'someones': "someone's", + 'somethingd': "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + 'somethingll': "something'll", + 'thats': "that's", + 'thered': "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + 'therere': "there're", + 'theres': "there's", + 'theyd': "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + 'theyll': "they'll", + 'theyre': "they're", + 'theyve': "they've", + 'twas': "'twas", + 'wasnt': "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + 'weve': "we've", + 'werent': "weren't", + 'whatll': "what'll", + 'whatre': "what're", + 'whats': "what's", + 'whatve': "what've", + 'whens': "when's", + 'whered': "where'd", + 'wheres': "where's", + 'whereve': "where've", + 'whod': "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + 'wholl': "who'll", + 'whos': "who's", + 'whove': "who've", + 'whyll': "why'll", + 'whyre': "why're", + 'whys': "why's", + 'wont': "won't", + 'wouldve': "would've", + 'wouldnt': "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + 'yall': "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + 'youd': "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + 'youll': "you'll", + 'youre': "you're", + 'youve': "you've", + } + for word in tempText: + word = manualMap.setdefault(word, word) + if word not in articles: + outText.append(word) + for wordId, word in enumerate(outText): + if word in contractions: + outText[wordId] = contractions[word] + outText = ' '.join(outText) + return outText + + +@METRICS.register_module() +class VQAAcc(BaseMetric): + '''VQA Acc metric. + Args: + + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + ''' + default_prefix = 'VQA' + + def __init__(self, + full_score_weight: float = 0.3, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.full_score_weight = full_score_weight + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + gt_answer_weight = sample.get('gt_answer_weight') + if isinstance(gt_answer, str): + gt_answer = [gt_answer] + if gt_answer_weight is None: + gt_answer_weight = [1. / (len(gt_answer))] * len(gt_answer) + + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer, + 'gt_answer_weight': gt_answer_weight, + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + acc = [] + for result in results: + pred_answer = self._process_answer(result['pred_answer']) + gt_answer = [ + self._process_answer(answer) for answer in result['gt_answer'] + ] + answer_weight = result['gt_answer_weight'] + + weight_sum = 0 + for i, gt in enumerate(gt_answer): + if gt == pred_answer: + weight_sum += answer_weight[i] + vqa_acc = min(1.0, weight_sum / self.full_score_weight) + acc.append(vqa_acc) + + accuracy = sum(acc) / len(acc) * 100 + + metrics = {'acc': accuracy} + return metrics + + def _process_answer(self, answer): + answer = answer.replace('\n', ' ') + answer = answer.replace('\t', ' ') + answer = answer.strip() + answer = _process_punctuation(answer) + answer = _process_digit_article(answer) + return answer + + +@METRICS.register_module() +class ReportVQA(BaseMetric): + """Dump VQA result to the standard json format for VQA evaluation. + + Args: + file_path (str): The file path to save the result file. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'VQA' + + def __init__(self, + file_path: str, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + if not file_path.endswith('.json'): + raise ValueError('The output file must be a json file.') + self.file_path = file_path + + def process(self, data_batch, data_samples) -> None: + """transfer tensors in predictions to CPU.""" + for sample in data_samples: + question_id = sample['question_id'] + pred_answer = sample['pred_answer'] + + result = { + 'question_id': int(question_id), + 'answer': pred_answer, + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Dump the result to json file.""" + mmengine.dump(results, self.file_path) + logger = MMLogger.get_current_instance() + logger.info(f'Results has been saved to {self.file_path}.') + return {} diff --git a/mmpretrain/models/__init__.py b/mmpretrain/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba05735b26a96cf532486f6f31d0d93bf6d30781 --- /dev/null +++ b/mmpretrain/models/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backbones import * # noqa: F401,F403 +from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS, + build_backbone, build_classifier, build_head, build_loss, + build_neck) +from .classifiers import * # noqa: F401,F403 +from .heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .multimodal import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .retrievers import * # noqa: F401,F403 +from .selfsup import * # noqa: F401,F403 +from .tta import * # noqa: F401,F403 +from .utils import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone', + 'build_head', 'build_neck', 'build_loss', 'build_classifier' +] diff --git a/mmpretrain/models/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c41a876a642a73a4f788bdd3865f4259626e85c Binary files /dev/null and b/mmpretrain/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/__pycache__/builder.cpython-38.pyc b/mmpretrain/models/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca6c2ea8fbc421d18eba37af09ec6a2c9296bc38 Binary files /dev/null and b/mmpretrain/models/__pycache__/builder.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60e37fb7b6e15cadd0eef4a3c9c79c856fbf4247 --- /dev/null +++ b/mmpretrain/models/backbones/__init__.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .alexnet import AlexNet +from .beit import BEiTViT +from .conformer import Conformer +from .convmixer import ConvMixer +from .convnext import ConvNeXt +from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt +from .davit import DaViT +from .deit import DistilledVisionTransformer +from .deit3 import DeiT3 +from .densenet import DenseNet +from .edgenext import EdgeNeXt +from .efficientformer import EfficientFormer +from .efficientnet import EfficientNet +from .efficientnet_v2 import EfficientNetV2 +from .hivit import HiViT +from .hornet import HorNet +from .hrnet import HRNet +from .inception_v3 import InceptionV3 +from .lenet import LeNet5 +from .levit import LeViT +from .mixmim import MixMIMTransformer +from .mlp_mixer import MlpMixer +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .mobileone import MobileOne +from .mobilevit import MobileViT +from .mvit import MViT +from .poolformer import PoolFormer +from .regnet import RegNet +from .replknet import RepLKNet +from .repmlp import RepMLPNet +from .repvgg import RepVGG +from .res2net import Res2Net +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1c, ResNetV1d +from .resnet_cifar import ResNet_CIFAR +from .resnext import ResNeXt +from .revvit import RevVisionTransformer +from .riformer import RIFormer +from .seresnet import SEResNet +from .seresnext import SEResNeXt +from .shufflenet_v1 import ShuffleNetV1 +from .shufflenet_v2 import ShuffleNetV2 +from .sparse_convnext import SparseConvNeXt +from .sparse_resnet import SparseResNet +from .swin_transformer import SwinTransformer +from .swin_transformer_v2 import SwinTransformerV2 +from .t2t_vit import T2T_ViT +from .timm_backbone import TIMMBackbone +from .tinyvit import TinyViT +from .tnt import TNT +from .twins import PCPVT, SVT +from .van import VAN +from .vgg import VGG +from .vig import PyramidVig, Vig +from .vision_transformer import VisionTransformer +from .vit_eva02 import ViTEVA02 +from .vit_sam import ViTSAM +from .xcit import XCiT + +__all__ = [ + 'LeNet5', + 'AlexNet', + 'VGG', + 'RegNet', + 'ResNet', + 'ResNeXt', + 'ResNetV1d', + 'ResNeSt', + 'ResNet_CIFAR', + 'SEResNet', + 'SEResNeXt', + 'ShuffleNetV1', + 'ShuffleNetV2', + 'MobileNetV2', + 'MobileNetV3', + 'VisionTransformer', + 'SwinTransformer', + 'TNT', + 'TIMMBackbone', + 'T2T_ViT', + 'Res2Net', + 'RepVGG', + 'Conformer', + 'MlpMixer', + 'DistilledVisionTransformer', + 'PCPVT', + 'SVT', + 'EfficientNet', + 'EfficientNetV2', + 'ConvNeXt', + 'HRNet', + 'ResNetV1c', + 'ConvMixer', + 'EdgeNeXt', + 'CSPDarkNet', + 'CSPResNet', + 'CSPResNeXt', + 'CSPNet', + 'RepLKNet', + 'RepMLPNet', + 'PoolFormer', + 'RIFormer', + 'DenseNet', + 'VAN', + 'InceptionV3', + 'MobileOne', + 'EfficientFormer', + 'SwinTransformerV2', + 'MViT', + 'DeiT3', + 'HorNet', + 'MobileViT', + 'DaViT', + 'BEiTViT', + 'RevVisionTransformer', + 'MixMIMTransformer', + 'TinyViT', + 'LeViT', + 'Vig', + 'PyramidVig', + 'XCiT', + 'ViTSAM', + 'ViTEVA02', + 'HiViT', + 'SparseResNet', + 'SparseConvNeXt', +] diff --git a/mmpretrain/models/backbones/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a18015ba1008f366339d501214f73c9dc36c0617 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/alexnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/alexnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8252336c00e00abc793eb11a7293a5663019042 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/alexnet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a5a615152cc30549993dc9119a62e59de5f60b4 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/beit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/beit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1df3344f2f3ea918b6d0c8a5deb9c2060759012e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/beit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/conformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/conformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30f2cfba565d203df5964dac416504857a937e26 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/conformer.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/convmixer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/convmixer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77489e2b4e9c55950d7826bfe177b98f09e66019 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/convmixer.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/convnext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/convnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f2ab5471ab21c9fa8ac7b93f1b290794a6d0309 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/convnext.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/cspnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/cspnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e436a622fd452b53f6bb80ade3e803bb8fdd841f Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/cspnet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/davit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/davit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca66272b4bfa46fc3e4b8eefe5c1046a1de38712 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/davit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/deit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/deit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fe28841c8abc0c0c88a5f0442178acef2d7a276 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/deit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/deit3.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/deit3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b937fdd2cb043f91bb42e642285b3027fa6b51f1 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/deit3.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/densenet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/densenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c503034f759b2cfbcb757d5502de1b4ffa65b601 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/densenet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/edgenext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/edgenext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ef53e8353e168a44e4916564dfaf1e0fcf6a540 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/edgenext.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa4d9250ffde11fd47f2f62147a34db35b369749 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e77b8d52c0eff923fc0eaa14b17afec772f1512 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36b782fb056e32738cf7b567a635bd06624f95f5 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/hivit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/hivit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e537dbf1c7edf853ed1af9ae690ed60ecd7950 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hivit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/hornet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/hornet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3186c22a7964393ca3b4c36f7ff9eda9bb620ff Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hornet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/hrnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/hrnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5a7d68884b2e20fbd361c9ce71a29ac6d32b627 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hrnet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc09fd4f405b724744e9de9ff79cbc53bd48883e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/lenet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/lenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d6ecc0ebe6386b5c9cf85df200e87afe8fd2eb8 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/lenet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/levit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/levit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..060e55647661641895839d884f02f02b2bda9069 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/levit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mixmim.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mixmim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f18cced1cb4c0c754e2f23abfe1c0822991e96e6 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mixmim.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ced907fb4d9214d35b4565f242f3ed2ef941246c Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e63399bd69b103ad71bf826014967c7214cfa682 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24d773a2b4305cdfcae87c82d5f3dd49183fea84 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mobileone.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mobileone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64bdcefa7ca0812e5ff42260616cb5a8adee58b0 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobileone.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7537bdce36b7680cf66ecc26bc5cb94b9bfc86b6 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mvit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mvit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f315d6b86d25569c9d8943b21efb26ef4e755246 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mvit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/poolformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/poolformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0bd0f5096e651c18b11a34747371fed5264378 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/poolformer.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/regnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/regnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..411e00066295c655a7ff6dfa02e437c32f9b5980 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/regnet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/replknet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/replknet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ef5ee77a958c77fb22e03866d43aa42cb8088b3 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/replknet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/repmlp.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/repmlp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..739b78287c2da67f7b970a24278ed5e93afddfec Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/repmlp.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/repvgg.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/repvgg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30ef3b932d009635d9e4b03e9c8edd6f899f6cd2 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/repvgg.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/res2net.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/res2net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1dbcc4def3590141f489a8c95d97a5eaba60245 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/res2net.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/resnest.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/resnest.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe35027e6548be56b8d8e2b34a04260e447a9cd9 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnest.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/resnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b1d529dd69bd33fe0346b5ae7ae7708d0f1c5eb Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a0ba27cdca835380c50713eaf8575ce8a92e974 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/resnext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/resnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..301b46e4540092cec011402493ac7a9c703cd643 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnext.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/revvit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/revvit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dd40b6e7c27389f7a5157fa58d719ecf7880530 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/revvit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/riformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/riformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26bda9f9ae80d91236aca39cc3ac448a2a779122 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/riformer.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/seresnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/seresnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6929011e417b1127ba452cc77848d68928d7f95 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/seresnet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/seresnext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/seresnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09e7556b191544153b1ceb39a82ea0649d4b717c Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/seresnext.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96de19f10dfd4c35eba969159ae03ea6c66742d9 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d589fb82cae8f529c64aa309c9e8837da0330d35 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..278f96e9720e827c7a33b3199fdf837299cad7c5 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31a1546e7fbe78b4207aa6f2f1ba91f2770500c5 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6207cd2f5bab62b11a9323541eae626a93ee6487 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e07f5d0b8dac62c6b17dcc130ad1ff547bd23a1e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbff543a621817397f9afbd16e1fd678a2c1b4d2 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e3ca7b55658448bad1c96c967fb30ce0ca83e30 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6ba4ca22ab620df1d8f85807dd7169c0bd4114c Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/tnt.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/tnt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c3acc5611e9bfdf0179549fc94b9437ca05eda5 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/tnt.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/twins.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/twins.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08f9fc896877413f4dc9f5e78e27c80858df0074 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/twins.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/van.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/van.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f86dc3fc60bde313c4a13519c6959f1fac861152 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/van.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vgg.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vgg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66383d035cf49ff7bcf1932d7a395c682ce74d2a Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vgg.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vig.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vig.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa1ab0b901c39b4e67e1689c9db21fd4fdee9e5 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vig.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30a69c3464095df54c17ae4d08d9e1720e321f27 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5deadcacffec2b134c8314608dc82dfd7bc87d9 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8129a755899798c5dafe11defd88b2f607a03f8 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/xcit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/xcit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4e5da85c50f64c34ec9ed8f20de44a0592900f3 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/xcit.cpython-38.pyc differ diff --git a/mmpretrain/models/backbones/alexnet.py b/mmpretrain/models/backbones/alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c2891fdd2c878e243331f572f6e3e562232d46 --- /dev/null +++ b/mmpretrain/models/backbones/alexnet.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class AlexNet(BaseBackbone): + """`AlexNet `_ backbone. + + The input for AlexNet is a 224x224 RGB image. + + Args: + num_classes (int): number of classes for classification. + The default value is -1, which uses the backbone as + a feature extractor without the top classifier. + """ + + def __init__(self, num_classes=-1): + super(AlexNet, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + + return (x, ) diff --git a/mmpretrain/models/backbones/base_backbone.py b/mmpretrain/models/backbones/base_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..751aa956ba2ad178ea9e40875b6e610ee7bbbcd3 --- /dev/null +++ b/mmpretrain/models/backbones/base_backbone.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +from mmengine.model import BaseModule + + +class BaseBackbone(BaseModule, metaclass=ABCMeta): + """Base backbone. + + This class defines the basic functions of a backbone. Any backbone that + inherits this class should at least define its own `forward` function. + """ + + def __init__(self, init_cfg=None): + super(BaseBackbone, self).__init__(init_cfg) + + @abstractmethod + def forward(self, x): + """Forward computation. + + Args: + x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of + Torch.tensor, containing input data for forward computation. + """ + pass + + def train(self, mode=True): + """Set module status before forward computation. + + Args: + mode (bool): Whether it is train_mode or test_mode + """ + super(BaseBackbone, self).train(mode) diff --git a/mmpretrain/models/backbones/beit.py b/mmpretrain/models/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..8f64ae2029b8be47d938fdef25aed9c0058ef307 --- /dev/null +++ b/mmpretrain/models/backbones/beit.py @@ -0,0 +1,522 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed, + resize_relative_position_bias_table, to_2tuple) +from .vision_transformer import TransformerEncoderLayer, VisionTransformer + + +class RelativePositionBias(BaseModule): + """Relative Position Bias. + + This module is copied from + https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py#L209. + + Args: + window_size (Sequence[int]): The window size of the relative + position bias. + num_heads (int): The number of head in multi-head attention. + with_cls_token (bool): To indicate the backbone has cls_token or not. + Defaults to True. + """ + + def __init__( + self, + window_size: Sequence[int], + num_heads: int, + with_cls_token: bool = True, + ) -> None: + super().__init__() + self.window_size = window_size + if with_cls_token: + num_extra_tokens = 3 + else: + num_extra_tokens = 0 + # cls to token & token to cls & cls to cls + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1) + num_extra_tokens + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each + # token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] -\ + coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + if with_cls_token: + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1, ) * 2, + dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum( + -1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + else: + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1], ) * 2, + dtype=relative_coords.dtype) + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + + self.register_buffer('relative_position_index', + relative_position_index) + + def forward(self) -> torch.Tensor: + # Wh*Ww,Wh*Ww,nH + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) + return relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class BEiTTransformerEncoderLayer(TransformerEncoderLayer): + """Implements one encoder layer in BEiT. + + Comparing with conventional ``TransformerEncoderLayer``, this module + adds weights to the shortcut connection. In addition, ``BEiTAttention`` + is used to replace the original ``MultiheadAttention`` in + ``TransformerEncoderLayer``. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. 1 means no scaling. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + window_size (tuple[int]): The height and width of the window. + Defaults to None. + use_rel_pos_bias (bool): Whether to use unique relative position bias, + if False, use shared relative position bias defined in backbone. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + bias (bool | str): The option to add leanable bias for q, k, v. If bias + is True, it will add leanable bias. If bias is 'qv_bias', it will + only add leanable bias for q, v. If bias is False, it will not add + bias for q, k, v. Default to 'qv_bias'. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='LN'). + attn_cfg (dict): The configuration for the attention layer. + Defaults to an empty dict. + ffn_cfg (dict): The configuration for the ffn layer. + Defaults to ``dict(add_identity=False)``. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + layer_scale_init_value: float, + window_size: Tuple[int, int], + use_rel_pos_bias: bool, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + bias: Union[str, bool] = 'qv_bias', + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + attn_cfg: dict = dict(), + ffn_cfg: dict = dict(add_identity=False), + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + attn_drop_rate=attn_drop_rate, + drop_path_rate=0., + drop_rate=0., + num_fcs=num_fcs, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + attn_cfg = { + 'window_size': window_size, + 'use_rel_pos_bias': use_rel_pos_bias, + 'qk_scale': None, + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'attn_drop': attn_drop_rate, + 'proj_drop': drop_rate, + 'bias': bias, + **attn_cfg, + } + self.attn = BEiTAttention(**attn_cfg) + + ffn_cfg = { + 'embed_dims': embed_dims, + 'feedforward_channels': feedforward_channels, + 'num_fcs': num_fcs, + 'ffn_drop': drop_rate, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path_rate), + 'act_cfg': act_cfg, + **ffn_cfg, + } + self.ffn = FFN(**ffn_cfg) + + # NOTE: drop path for stochastic depth, we shall see if + # this is better than dropout here + dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate) + self.drop_path = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + + if layer_scale_init_value > 0: + self.gamma_1 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + self.gamma_2 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x: torch.Tensor, + rel_pos_bias: torch.Tensor) -> torch.Tensor: + if self.gamma_1 is None: + x = x + self.drop_path( + self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.ffn(self.ln2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn( + self.ln1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x))) + return x + + +@MODELS.register_module() +class BEiTViT(VisionTransformer): + """Backbone for BEiT. + + A PyTorch implement of : `BEiT: BERT Pre-Training of Image Transformers + `_ + A PyTorch implement of : `BEiT v2: Masked Image Modeling with + Vector-Quantized Visual Tokenizers `_ + + Args: + arch (str | dict): BEiT architecture. If use string, choose from + 'base', 'large'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + bias (bool | str): The option to add leanable bias for q, k, v. If bias + is True, it will add leanable bias. If bias is 'qv_bias', it will + only add leanable bias for q, v. If bias is False, it will not add + bias for q, k, v. Default to 'qv_bias'. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"avg_featmap"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + use_abs_pos_emb (bool): Use position embedding like vanilla ViT. + Defaults to False. + use_rel_pos_bias (bool): Use relative position embedding in each + transformer encoder layer. Defaults to True. + use_shared_rel_pos_bias (bool): Use shared relative position embedding, + all transformer encoder layers share the same relative position + embedding. Defaults to False. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. Defaults to 0.1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0, + drop_path_rate=0, + bias='qv_bias', + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=False, + out_type='avg_featmap', + with_cls_token=True, + frozen_stages=-1, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + layer_scale_init_value=0.1, + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + if use_abs_pos_emb: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + else: + self.pos_embed = None + self.drop_after_pos = nn.Dropout(p=drop_rate) + + assert not (use_rel_pos_bias and use_shared_rel_pos_bias), ( + '`use_rel_pos_bias` and `use_shared_rel_pos_bias` cannot be set ' + 'to True at the same time') + self.use_rel_pos_bias = use_rel_pos_bias + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias( + window_size=self.patch_resolution, + num_heads=self.arch_settings['num_heads']) + else: + self.rel_pos_bias = None + self._register_load_state_dict_pre_hook( + self._prepare_relative_position_bias_table) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + layer_scale_init_value=layer_scale_init_value, + window_size=self.patch_resolution, + use_rel_pos_bias=use_rel_pos_bias, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + bias=bias, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg)) + + self.frozen_stages = frozen_stages + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + if out_type == 'avg_featmap': + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + if self.pos_embed is not None: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + rel_pos_bias = self.rel_pos_bias() \ + if self.rel_pos_bias is not None else None + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, rel_pos_bias) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(patch_token.mean(dim=1)) + + def _prepare_relative_position_bias_table(self, state_dict, prefix, *args, + **kwargs): + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + + if self.use_rel_pos_bias and 'rel_pos_bias.relative_position_bias_table' in state_dict: # noqa:E501 + logger.info('Expand the shared relative position embedding to ' + 'each transformer block.') + rel_pos_bias = state_dict[ + 'rel_pos_bias.relative_position_bias_table'] + for i in range(self.num_layers): + state_dict[ + f'layers.{i}.attn.relative_position_bias_table'] = \ + rel_pos_bias.clone() + state_dict.pop('rel_pos_bias.relative_position_bias_table') + state_dict.pop('rel_pos_bias.relative_position_index') + + state_dict_model = self.state_dict() + all_keys = list(state_dict_model.keys()) + for key in all_keys: + if 'relative_position_bias_table' in key: + ckpt_key = prefix + key + if ckpt_key not in state_dict: + continue + rel_pos_bias_pretrained = state_dict[ckpt_key] + rel_pos_bias_current = state_dict_model[key] + L1, nH1 = rel_pos_bias_pretrained.size() + L2, nH2 = rel_pos_bias_current.size() + src_size = int((L1 - 3)**0.5) + dst_size = int((L2 - 3)**0.5) + if L1 != L2: + extra_tokens = rel_pos_bias_pretrained[-3:, :] + rel_pos_bias = rel_pos_bias_pretrained[:-3, :] + + new_rel_pos_bias = resize_relative_position_bias_table( + src_size, dst_size, rel_pos_bias, nH1) + new_rel_pos_bias = torch.cat( + (new_rel_pos_bias, extra_tokens), dim=0) + logger.info('Resize the relative_position_bias_table from ' + f'{state_dict[ckpt_key].shape} to ' + f'{new_rel_pos_bias.shape}') + state_dict[ckpt_key] = new_rel_pos_bias + + # The index buffer need to be re-generated. + index_buffer = ckpt_key.replace('bias_table', 'index') + if index_buffer in state_dict: + del state_dict[index_buffer] diff --git a/mmpretrain/models/backbones/conformer.py b/mmpretrain/models/backbones/conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..eda72b0595b6923a7f1f563ae7186ca533f85023 --- /dev/null +++ b/mmpretrain/models/backbones/conformer.py @@ -0,0 +1,621 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.drop import DropPath +from mmcv.cnn.bricks.transformer import AdaptivePadding +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone +from .vision_transformer import TransformerEncoderLayer + + +class ConvBlock(BaseModule): + """Basic convluation block used in Conformer. + + This block includes three convluation modules, and supports three new + functions: + 1. Returns the output of both the final layers and the second convluation + module. + 2. Fuses the input of the second convluation module with an extra input + feature map. + 3. Supports to add an extra convluation module to the identity connection. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + stride (int): The stride of the second convluation module. + Defaults to 1. + groups (int): The groups of the second convluation module. + Defaults to 1. + drop_path_rate (float): The rate of the DropPath layer. Defaults to 0. + with_residual_conv (bool): Whether to add an extra convluation module + to the identity connection. Defaults to False. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='BN', eps=1e-6)``. + act_cfg (dict): The config of activative functions. + Defaults to ``dict(type='ReLU', inplace=True))``. + init_cfg (dict, optional): The extra config to initialize the module. + Defaults to None. + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + groups=1, + drop_path_rate=0., + with_residual_conv=False, + norm_cfg=dict(type='BN', eps=1e-6), + act_cfg=dict(type='ReLU', inplace=True), + init_cfg=None): + super(ConvBlock, self).__init__(init_cfg=init_cfg) + + expansion = 4 + mid_channels = out_channels // expansion + + self.conv1 = nn.Conv2d( + in_channels, + mid_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1] + self.act1 = build_activation_layer(act_cfg) + + self.conv2 = nn.Conv2d( + mid_channels, + mid_channels, + kernel_size=3, + stride=stride, + groups=groups, + padding=1, + bias=False) + self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1] + self.act2 = build_activation_layer(act_cfg) + + self.conv3 = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.bn3 = build_norm_layer(norm_cfg, out_channels)[1] + self.act3 = build_activation_layer(act_cfg) + + if with_residual_conv: + self.residual_conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + padding=0, + bias=False) + self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1] + + self.with_residual_conv = with_residual_conv + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x, fusion_features=None, out_conv2=True): + identity = x + + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) if fusion_features is None else self.conv2( + x + fusion_features) + x = self.bn2(x) + x2 = self.act2(x) + + x = self.conv3(x2) + x = self.bn3(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.with_residual_conv: + identity = self.residual_conv(identity) + identity = self.residual_bn(identity) + + x += identity + x = self.act3(x) + + if out_conv2: + return x, x2 + else: + return x + + +class FCUDown(BaseModule): + """CNN feature maps -> Transformer patch embeddings.""" + + def __init__(self, + in_channels, + out_channels, + down_stride, + with_cls_token=True, + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(FCUDown, self).__init__(init_cfg=init_cfg) + self.down_stride = down_stride + self.with_cls_token = with_cls_token + + self.conv_project = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.sample_pooling = nn.AvgPool2d( + kernel_size=down_stride, stride=down_stride) + + self.ln = build_norm_layer(norm_cfg, out_channels)[1] + self.act = build_activation_layer(act_cfg) + + def forward(self, x, x_t): + x = self.conv_project(x) # [N, C, H, W] + + x = self.sample_pooling(x).flatten(2).transpose(1, 2) + x = self.ln(x) + x = self.act(x) + + if self.with_cls_token: + x = torch.cat([x_t[:, 0][:, None, :], x], dim=1) + + return x + + +class FCUUp(BaseModule): + """Transformer patch embeddings -> CNN feature maps.""" + + def __init__(self, + in_channels, + out_channels, + up_stride, + with_cls_token=True, + norm_cfg=dict(type='BN', eps=1e-6), + act_cfg=dict(type='ReLU', inplace=True), + init_cfg=None): + super(FCUUp, self).__init__(init_cfg=init_cfg) + + self.up_stride = up_stride + self.with_cls_token = with_cls_token + + self.conv_project = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.bn = build_norm_layer(norm_cfg, out_channels)[1] + self.act = build_activation_layer(act_cfg) + + def forward(self, x, H, W): + B, _, C = x.shape + # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14] + if self.with_cls_token: + x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W) + else: + x_r = x.transpose(1, 2).reshape(B, C, H, W) + + x_r = self.act(self.bn(self.conv_project(x_r))) + + return F.interpolate( + x_r, size=(H * self.up_stride, W * self.up_stride)) + + +class ConvTransBlock(BaseModule): + """Basic module for Conformer. + + This module is a fusion of CNN block transformer encoder block. + + Args: + in_channels (int): The number of input channels in conv blocks. + out_channels (int): The number of output channels in conv blocks. + embed_dims (int): The embedding dimension in transformer blocks. + conv_stride (int): The stride of conv2d layers. Defaults to 1. + groups (int): The groups of conv blocks. Defaults to 1. + with_residual_conv (bool): Whether to add a conv-bn layer to the + identity connect in the conv block. Defaults to False. + down_stride (int): The stride of the downsample pooling layer. + Defaults to 4. + num_heads (int): The number of heads in transformer attention layers. + Defaults to 12. + mlp_ratio (float): The expansion ratio in transformer FFN module. + Defaults to 4. + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + with_cls_token (bool): Whether use class token or not. + Defaults to True. + drop_rate (float): The dropout rate of the output projection and + FFN in the transformer block. Defaults to 0. + attn_drop_rate (float): The dropout rate after the attention + calculation in the transformer block. Defaults to 0. + drop_path_rate (bloat): The drop path rate in both the conv block + and the transformer block. Defaults to 0. + last_fusion (bool): Whether this block is the last stage. If so, + downsample the fusion feature map. + init_cfg (dict, optional): The extra config to initialize the module. + Defaults to None. + """ + + def __init__(self, + in_channels, + out_channels, + embed_dims, + conv_stride=1, + groups=1, + with_residual_conv=False, + down_stride=4, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + with_cls_token=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + last_fusion=False, + init_cfg=None): + super(ConvTransBlock, self).__init__(init_cfg=init_cfg) + expansion = 4 + self.cnn_block = ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + with_residual_conv=with_residual_conv, + stride=conv_stride, + groups=groups) + + if last_fusion: + self.fusion_block = ConvBlock( + in_channels=out_channels, + out_channels=out_channels, + stride=2, + with_residual_conv=True, + groups=groups, + drop_path_rate=drop_path_rate) + else: + self.fusion_block = ConvBlock( + in_channels=out_channels, + out_channels=out_channels, + groups=groups, + drop_path_rate=drop_path_rate) + + self.squeeze_block = FCUDown( + in_channels=out_channels // expansion, + out_channels=embed_dims, + down_stride=down_stride, + with_cls_token=with_cls_token) + + self.expand_block = FCUUp( + in_channels=embed_dims, + out_channels=out_channels // expansion, + up_stride=down_stride, + with_cls_token=with_cls_token) + + self.trans_block = TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=int(embed_dims * mlp_ratio), + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + attn_drop_rate=attn_drop_rate, + qkv_bias=qkv_bias, + norm_cfg=dict(type='LN', eps=1e-6)) + + self.down_stride = down_stride + self.embed_dim = embed_dims + self.last_fusion = last_fusion + + def forward(self, cnn_input, trans_input): + x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True) + + _, _, H, W = x_conv2.shape + + # Convert the feature map of conv2 to transformer embedding + # and concat with class token. + conv2_embedding = self.squeeze_block(x_conv2, trans_input) + + trans_output = self.trans_block(conv2_embedding + trans_input) + + # Convert the transformer output embedding to feature map + trans_features = self.expand_block(trans_output, H // self.down_stride, + W // self.down_stride) + x = self.fusion_block( + x, fusion_features=trans_features, out_conv2=False) + + return x, trans_output + + +@MODELS.register_module() +class Conformer(BaseBackbone): + """Conformer backbone. + + A PyTorch implementation of : `Conformer: Local Features Coupling Global + Representations for Visual Recognition `_ + + Args: + arch (str | dict): Conformer architecture. Defaults to 'tiny'. + patch_size (int): The patch size. Defaults to 16. + base_channels (int): The base number of channels in CNN network. + Defaults to 64. + mlp_ratio (float): The expansion ratio of FFN network in transformer + block. Defaults to 4. + with_cls_token (bool): Whether use class token or not. + Defaults to True. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 384, + 'channel_ratio': 1, + 'num_heads': 6, + 'depths': 12 + }), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 384, + 'channel_ratio': 4, + 'num_heads': 6, + 'depths': 12 + }), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 576, + 'channel_ratio': 6, + 'num_heads': 9, + 'depths': 12 + }), + } # yapf: disable + + _version = 1 + + def __init__(self, + arch='tiny', + patch_size=16, + base_channels=64, + mlp_ratio=4., + qkv_bias=True, + with_cls_token=True, + drop_path_rate=0., + norm_eval=True, + frozen_stages=0, + out_indices=-1, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'channel_ratio' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.num_features = self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.channel_ratio = self.arch_settings['channel_ratio'] + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.depths + index + 1 + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + self.with_cls_token = with_cls_token + if self.with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + # stochastic depth decay rule + self.trans_dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, self.depths) + ] + + # Stem stage: get the feature maps by conv block + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, + bias=False) # 1 / 2 [112, 112] + self.bn1 = nn.BatchNorm2d(64) + self.act1 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56] + + assert patch_size % 16 == 0, 'The patch size of Conformer must ' \ + 'be divisible by 16.' + trans_down_stride = patch_size // 4 + + # To solve the issue #680 + # Auto pad the feature map to be divisible by trans_down_stride + self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride) + + # 1 stage + stage1_channels = int(base_channels * self.channel_ratio) + self.conv_1 = ConvBlock( + in_channels=64, + out_channels=stage1_channels, + with_residual_conv=True, + stride=1) + self.trans_patch_conv = nn.Conv2d( + 64, + self.embed_dims, + kernel_size=trans_down_stride, + stride=trans_down_stride, + padding=0) + + self.trans_1 = TransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=int(self.embed_dims * mlp_ratio), + drop_path_rate=self.trans_dpr[0], + qkv_bias=qkv_bias, + norm_cfg=dict(type='LN', eps=1e-6)) + + # 2~4 stage + init_stage = 2 + fin_stage = self.depths // 3 + 1 + for i in range(init_stage, fin_stage): + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=stage1_channels, + out_channels=stage1_channels, + embed_dims=self.embed_dims, + conv_stride=1, + with_residual_conv=False, + down_stride=trans_down_stride, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token)) + + stage2_channels = int(base_channels * self.channel_ratio * 2) + # 5~8 stage + init_stage = fin_stage # 5 + fin_stage = fin_stage + self.depths // 3 # 9 + for i in range(init_stage, fin_stage): + if i == init_stage: + conv_stride = 2 + in_channels = stage1_channels + else: + conv_stride = 1 + in_channels = stage2_channels + + with_residual_conv = True if i == init_stage else False + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=in_channels, + out_channels=stage2_channels, + embed_dims=self.embed_dims, + conv_stride=conv_stride, + with_residual_conv=with_residual_conv, + down_stride=trans_down_stride // 2, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token)) + + stage3_channels = int(base_channels * self.channel_ratio * 2 * 2) + # 9~12 stage + init_stage = fin_stage # 9 + fin_stage = fin_stage + self.depths // 3 # 13 + for i in range(init_stage, fin_stage): + if i == init_stage: + conv_stride = 2 + in_channels = stage2_channels + with_residual_conv = True + else: + conv_stride = 1 + in_channels = stage3_channels + with_residual_conv = False + + last_fusion = (i == self.depths) + + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=in_channels, + out_channels=stage3_channels, + embed_dims=self.embed_dims, + conv_stride=conv_stride, + with_residual_conv=with_residual_conv, + down_stride=trans_down_stride // 4, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token, + last_fusion=last_fusion)) + self.fin_stage = fin_stage + + self.pooling = nn.AdaptiveAvgPool2d(1) + self.trans_norm = nn.LayerNorm(self.embed_dims) + + if self.with_cls_token: + trunc_normal_(self.cls_token, std=.02) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def init_weights(self): + super(Conformer, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + self.apply(self._init_weights) + + def forward(self, x): + output = [] + B = x.shape[0] + if self.with_cls_token: + cls_tokens = self.cls_token.expand(B, -1, -1) + + # stem + x_base = self.maxpool(self.act1(self.bn1(self.conv1(x)))) + x_base = self.auto_pad(x_base) + + # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56] + x = self.conv_1(x_base, out_conv2=False) + x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2) + if self.with_cls_token: + x_t = torch.cat([cls_tokens, x_t], dim=1) + x_t = self.trans_1(x_t) + + # 2 ~ final + for i in range(2, self.fin_stage): + stage = getattr(self, f'conv_trans_{i}') + x, x_t = stage(x, x_t) + if i in self.out_indices: + if self.with_cls_token: + output.append([ + self.pooling(x).flatten(1), + self.trans_norm(x_t)[:, 0] + ]) + else: + # if no class token, use the mean patch token + # as the transformer feature. + output.append([ + self.pooling(x).flatten(1), + self.trans_norm(x_t).mean(dim=1) + ]) + + return tuple(output) diff --git a/mmpretrain/models/backbones/convmixer.py b/mmpretrain/models/backbones/convmixer.py new file mode 100644 index 0000000000000000000000000000000000000000..480050d5ce1aa29f190dbc24ec1413573d541cb1 --- /dev/null +++ b/mmpretrain/models/backbones/convmixer.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer, + build_norm_layer) +from mmengine.utils import digit_version + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class Residual(nn.Module): + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +@MODELS.register_module() +class ConvMixer(BaseBackbone): + """ConvMixer. . + + A PyTorch implementation of : `Patches Are All You Need? + `_ + + Modified from the `official repo + `_ + and `timm + `_. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvMixer.arch_settings``. And if dict, it + should include the following two keys: + + - embed_dims (int): The dimensions of patch embedding. + - depth (int): Number of repetitions of ConvMixer Layer. + - patch_size (int): The patch size. + - kernel_size (int): The kernel size of depthwise conv layers. + + Defaults to '768/32'. + in_channels (int): Number of input image channels. Defaults to 3. + patch_size (int): The size of one patch in the patch embed layer. + Defaults to 7. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation after each convolution. + Defaults to ``dict(type='GELU')``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict. + """ + arch_settings = { + '768/32': { + 'embed_dims': 768, + 'depth': 32, + 'patch_size': 7, + 'kernel_size': 7 + }, + '1024/20': { + 'embed_dims': 1024, + 'depth': 20, + 'patch_size': 14, + 'kernel_size': 9 + }, + '1536/20': { + 'embed_dims': 1536, + 'depth': 20, + 'patch_size': 7, + 'kernel_size': 9 + }, + } + + def __init__(self, + arch='768/32', + in_channels=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + out_indices=-1, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = { + 'embed_dims', 'depth', 'patch_size', 'kernel_size' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + + self.embed_dims = arch['embed_dims'] + self.depth = arch['depth'] + self.patch_size = arch['patch_size'] + self.kernel_size = arch['kernel_size'] + self.act = build_activation_layer(act_cfg) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.depth + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # Set stem layers + self.stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.embed_dims, + kernel_size=self.patch_size, + stride=self.patch_size), self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1]) + + # Set conv2d according to torch version + convfunc = nn.Conv2d + if digit_version(torch.__version__) < digit_version('1.9.0'): + convfunc = Conv2dAdaptivePadding + + # Repetitions of ConvMixer Layer + self.stages = nn.Sequential(*[ + nn.Sequential( + Residual( + nn.Sequential( + convfunc( + self.embed_dims, + self.embed_dims, + self.kernel_size, + groups=self.embed_dims, + padding='same'), self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1])), + nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1), + self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1]) + for _ in range(self.depth) + ]) + + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.out_indices: + outs.append(x) + + # x = self.pooling(x).flatten(1) + return tuple(outs) + + def train(self, mode=True): + super(ConvMixer, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False diff --git a/mmpretrain/models/backbones/convnext.py b/mmpretrain/models/backbones/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..6a954f5b980186a86565a228669c6917bda14f68 --- /dev/null +++ b/mmpretrain/models/backbones/convnext.py @@ -0,0 +1,412 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import GRN, build_norm_layer +from .base_backbone import BaseBackbone + + +class ConvNeXtBlock(BaseModule): + """ConvNeXt Block. + + Args: + in_channels (int): The number of input channels. + dw_conv_cfg (dict): Config of depthwise convolution. + Defaults to ``dict(kernel_size=7, padding=3)``. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + mlp_ratio (float): The expansion ratio in both pointwise convolution. + Defaults to 4. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. More details can be found in the note. + Defaults to True. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + + Note: + There are two equivalent implementations: + + 1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv; + all outputs are in (N, C, H, W). + 2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU + -> Linear; Permute back + + As default, we use the second to align with the official repository. + And it may be slightly faster. + """ + + def __init__(self, + in_channels, + dw_conv_cfg=dict(kernel_size=7, padding=3), + norm_cfg=dict(type='LN2d', eps=1e-6), + act_cfg=dict(type='GELU'), + mlp_ratio=4., + linear_pw_conv=True, + drop_path_rate=0., + layer_scale_init_value=1e-6, + use_grn=False, + with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.depthwise_conv = nn.Conv2d( + in_channels, in_channels, groups=in_channels, **dw_conv_cfg) + + self.linear_pw_conv = linear_pw_conv + self.norm = build_norm_layer(norm_cfg, in_channels) + + mid_channels = int(mlp_ratio * in_channels) + if self.linear_pw_conv: + # Use linear layer to do pointwise conv. + pw_conv = nn.Linear + else: + pw_conv = partial(nn.Conv2d, kernel_size=1) + + self.pointwise_conv1 = pw_conv(in_channels, mid_channels) + self.act = MODELS.build(act_cfg) + self.pointwise_conv2 = pw_conv(mid_channels, in_channels) + + if use_grn: + self.grn = GRN(mid_channels) + else: + self.grn = None + + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((in_channels)), + requires_grad=True) if layer_scale_init_value > 0 else None + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + shortcut = x + x = self.depthwise_conv(x) + + if self.linear_pw_conv: + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x, data_format='channel_last') + x = self.pointwise_conv1(x) + x = self.act(x) + if self.grn is not None: + x = self.grn(x, data_format='channel_last') + x = self.pointwise_conv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + else: + x = self.norm(x, data_format='channel_first') + x = self.pointwise_conv1(x) + x = self.act(x) + + if self.grn is not None: + x = self.grn(x, data_format='channel_first') + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = x.mul(self.gamma.view(1, -1, 1, 1)) + + x = shortcut + self.drop_path(x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class ConvNeXt(BaseBackbone): + """ConvNeXt v1&v2 backbone. + + A PyTorch implementation of `A ConvNet for the 2020s + `_ and + `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders + `_ + + Modified from the `official repo + `_ + and `timm + `_. + + To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvNeXt.arch_settings``. And if dict, it + should include the following two keys: + + - depths (list[int]): Number of blocks at each stage. + - channels (list[int]): The number of channels at each stage. + + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + stem_patch_size (int): The size of one patch in the stem layer. + Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to True. + use_grn (bool): Whether to add Global Response Normalization in the + blocks. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): Initialization config dict + """ # noqa: E501 + arch_settings = { + 'atto': { + 'depths': [2, 2, 6, 2], + 'channels': [40, 80, 160, 320] + }, + 'femto': { + 'depths': [2, 2, 6, 2], + 'channels': [48, 96, 192, 384] + }, + 'pico': { + 'depths': [2, 2, 6, 2], + 'channels': [64, 128, 256, 512] + }, + 'nano': { + 'depths': [2, 2, 8, 2], + 'channels': [80, 160, 320, 640] + }, + 'tiny': { + 'depths': [3, 3, 9, 3], + 'channels': [96, 192, 384, 768] + }, + 'small': { + 'depths': [3, 3, 27, 3], + 'channels': [96, 192, 384, 768] + }, + 'base': { + 'depths': [3, 3, 27, 3], + 'channels': [128, 256, 512, 1024] + }, + 'large': { + 'depths': [3, 3, 27, 3], + 'channels': [192, 384, 768, 1536] + }, + 'xlarge': { + 'depths': [3, 3, 27, 3], + 'channels': [256, 512, 1024, 2048] + }, + 'huge': { + 'depths': [3, 3, 27, 3], + 'channels': [352, 704, 1408, 2816] + } + } + + def __init__(self, + arch='tiny', + in_channels=3, + stem_patch_size=4, + norm_cfg=dict(type='LN2d', eps=1e-6), + act_cfg=dict(type='GELU'), + linear_pw_conv=True, + use_grn=False, + drop_path_rate=0., + layer_scale_init_value=1e-6, + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + with_cp=False, + init_cfg=[ + dict( + type='TruncNormal', + layer=['Conv2d', 'Linear'], + std=.02, + bias=0.), + dict( + type='Constant', layer=['LayerNorm'], val=1., + bias=0.), + ]): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'depths' in arch and 'channels' in arch, \ + f'The arch dict must have "depths" and "channels", ' \ + f'but got {list(arch.keys())}.' + + self.depths = arch['depths'] + self.channels = arch['channels'] + assert (isinstance(self.depths, Sequence) + and isinstance(self.channels, Sequence) + and len(self.depths) == len(self.channels)), \ + f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \ + 'should be both sequence with the same length.' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + block_idx = 0 + + # 4 downsample layers between stages, including the stem layer. + self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.channels[0], + kernel_size=stem_patch_size, + stride=stem_patch_size), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + # 4 feature resolution stages, each consisting of multiple residual + # blocks + self.stages = nn.ModuleList() + + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2), + ) + self.downsample_layers.append(downsample_layer) + + stage = Sequential(*[ + ConvNeXtBlock( + in_channels=channels, + drop_path_rate=dpr[block_idx + j], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + layer_scale_init_value=layer_scale_init_value, + use_grn=use_grn, + with_cp=with_cp) for j in range(depth) + ]) + block_idx += depth + + self.stages.append(stage) + + if i in self.out_indices: + norm_layer = build_norm_layer(norm_cfg, channels) + self.add_module(f'norm{i}', norm_layer) + + self._freeze_stages() + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap).flatten(1)) + else: + outs.append(norm_layer(x)) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.downsample_layers[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(ConvNeXt, self).train(mode) + self._freeze_stages() + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + """ + + max_layer_id = 12 if self.depths[-2] > 9 else 6 + + if not param_name.startswith(prefix): + # For subsequent module like head + return max_layer_id + 1, max_layer_id + 2 + + param_name = param_name[len(prefix):] + if param_name.startswith('downsample_layers'): + stage_id = int(param_name.split('.')[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + else: # stage_id == 3: + layer_id = max_layer_id + + elif param_name.startswith('stages'): + stage_id = int(param_name.split('.')[1]) + block_id = int(param_name.split('.')[2]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + else: # stage_id == 3: + layer_id = max_layer_id + + # final norm layer + else: + layer_id = max_layer_id + 1 + + return layer_id, max_layer_id + 2 diff --git a/mmpretrain/models/backbones/cspnet.py b/mmpretrain/models/backbones/cspnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7492e97702c28861dcce2808207a35e67f32f752 --- /dev/null +++ b/mmpretrain/models/backbones/cspnet.py @@ -0,0 +1,679 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import to_ntuple +from .resnet import Bottleneck as ResNetBottleneck +from .resnext import Bottleneck as ResNeXtBottleneck + +eps = 1.0e-5 + + +class DarknetBottleneck(BaseModule): + """The basic bottleneck block used in Darknet. Each DarknetBottleneck + consists of two ConvModules and the input is added to the final output. + Each ConvModule is composed of Conv, BN, and LeakyReLU. The first convLayer + has filter size of 1x1 and the second one has the filter size of 3x3. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the input/output channels of conv2. + Defaults to 4. + add_identity (bool): Whether to add identity to the out. + Defaults to True. + use_depthwise (bool): Whether to use depthwise separable convolution. + Defaults to False. + conv_cfg (dict): Config dict for convolution layer. Defaults to None, + which means using conv2d. + drop_path_rate (float): The ratio of the drop path layer. Default: 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN', eps=1e-5)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='Swish')``. + """ + + def __init__(self, + in_channels, + out_channels, + expansion=2, + add_identity=True, + use_depthwise=False, + conv_cfg=None, + drop_path_rate=0, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + init_cfg=None): + super().__init__(init_cfg) + hidden_channels = int(out_channels / expansion) + conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule + self.conv1 = ConvModule( + in_channels, + hidden_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = conv( + hidden_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.add_identity = \ + add_identity and in_channels == out_channels + + self.drop_path = DropPath(drop_prob=drop_path_rate + ) if drop_path_rate > eps else nn.Identity() + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.conv2(out) + out = self.drop_path(out) + + if self.add_identity: + return out + identity + else: + return out + + +class CSPStage(BaseModule): + """Cross Stage Partial Stage. + + .. code:: text + + Downsample Convolution (optional) + | + | + Expand Convolution + | + | + Split to xa, xb + | \ + | \ + | blocks(xb) + | / + | / transition + | / + Concat xa, blocks(xb) + | + Transition Convolution + + Args: + block_fn (nn.module): The basic block function in the Stage. + in_channels (int): The input channels of the CSP layer. + out_channels (int): The output channels of the CSP layer. + has_downsampler (bool): Whether to add a downsampler in the stage. + Default: False. + down_growth (bool): Whether to expand the channels in the + downsampler layer of the stage. Default: False. + expand_ratio (float): The expand ratio to adjust the number of + channels of the expand conv layer. Default: 0.5 + bottle_ratio (float): Ratio to adjust the number of channels of the + hidden layer. Default: 0.5 + block_dpr (float): The ratio of the drop path layer in the + blocks of the stage. Default: 0. + num_blocks (int): Number of blocks. Default: 1 + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', inplace=True) + """ + + def __init__(self, + block_fn, + in_channels, + out_channels, + has_downsampler=True, + down_growth=False, + expand_ratio=0.5, + bottle_ratio=2, + num_blocks=1, + block_dpr=0, + block_args={}, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + init_cfg=None): + super().__init__(init_cfg) + # grow downsample channels to output channels + down_channels = out_channels if down_growth else in_channels + block_dpr = to_ntuple(num_blocks)(block_dpr) + + if has_downsampler: + self.downsample_conv = ConvModule( + in_channels=in_channels, + out_channels=down_channels, + kernel_size=3, + stride=2, + padding=1, + groups=32 if block_fn is ResNeXtBottleneck else 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.downsample_conv = nn.Identity() + + exp_channels = int(down_channels * expand_ratio) + self.expand_conv = ConvModule( + in_channels=down_channels, + out_channels=exp_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg if block_fn is DarknetBottleneck else None) + + assert exp_channels % 2 == 0, \ + 'The channel number before blocks must be divisible by 2.' + block_channels = exp_channels // 2 + blocks = [] + for i in range(num_blocks): + block_cfg = dict( + in_channels=block_channels, + out_channels=block_channels, + expansion=bottle_ratio, + drop_path_rate=block_dpr[i], + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **block_args) + blocks.append(block_fn(**block_cfg)) + self.blocks = Sequential(*blocks) + self.atfer_blocks_conv = ConvModule( + block_channels, + block_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.final_conv = ConvModule( + 2 * block_channels, + out_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + x = self.downsample_conv(x) + x = self.expand_conv(x) + + split = x.shape[1] // 2 + xa, xb = x[:, :split], x[:, split:] + + xb = self.blocks(xb) + xb = self.atfer_blocks_conv(xb).contiguous() + + x_final = torch.cat((xa, xb), dim=1) + return self.final_conv(x_final) + + +class CSPNet(BaseModule): + """The abstract CSP Network class. + + A Pytorch implementation of `CSPNet: A New Backbone that can Enhance + Learning Capability of CNN `_ + + This class is an abstract class because the Cross Stage Partial Network + (CSPNet) is a kind of universal network structure, and you + network block to implement networks like CSPResNet, CSPResNeXt and + CSPDarkNet. + + Args: + arch (dict): The architecture of the CSPNet. + It should have the following keys: + + - block_fn (Callable): A function or class to return a block + module, and it should accept at least ``in_channels``, + ``out_channels``, ``expansion``, ``drop_path_rate``, ``norm_cfg`` + and ``act_cfg``. + - in_channels (Tuple[int]): The number of input channels of each + stage. + - out_channels (Tuple[int]): The number of output channels of each + stage. + - num_blocks (Tuple[int]): The number of blocks in each stage. + - expansion_ratio (float | Tuple[float]): The expansion ratio in + the expand convolution of each stage. Defaults to 0.5. + - bottle_ratio (float | Tuple[float]): The expansion ratio of + blocks in each stage. Defaults to 2. + - has_downsampler (bool | Tuple[bool]): Whether to add a + downsample convolution in each stage. Defaults to True + - down_growth (bool | Tuple[bool]): Whether to expand the channels + in the downsampler layer of each stage. Defaults to False. + - block_args (dict | Tuple[dict], optional): The extra arguments to + the blocks in each stage. Defaults to None. + + stem_fn (Callable): A function or class to return a stem module. + And it should accept ``in_channels``. + in_channels (int): Number of input image channels. Defaults to 3. + out_indices (int | Sequence[int]): Output from which stages. + Defaults to -1, which means the last stage. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + conv_cfg (dict, optional): The config dict for conv layers in blocks. + Defaults to None, which means use Conv2d. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN', eps=1e-5)``. + act_cfg (dict): The config dict for activation functions. + Defaults to ``dict(type='LeakyReLU', inplace=True)``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + init_cfg (dict, optional): The initialization settings. + Defaults to ``dict(type='Kaiming', layer='Conv2d'))``. + + Example: + >>> from functools import partial + >>> import torch + >>> import torch.nn as nn + >>> from mmpretrain.models import CSPNet + >>> from mmpretrain.models.backbones.resnet import Bottleneck + >>> + >>> # A simple example to build CSPNet. + >>> arch = dict( + ... block_fn=Bottleneck, + ... in_channels=[32, 64], + ... out_channels=[64, 128], + ... num_blocks=[3, 4] + ... ) + >>> stem_fn = partial(nn.Conv2d, out_channels=32, kernel_size=3) + >>> model = CSPNet(arch=arch, stem_fn=stem_fn, out_indices=(0, 1)) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outs = model(inputs) + >>> for out in outs: + ... print(out.shape) + ... + (1, 64, 111, 111) + (1, 128, 56, 56) + """ + + def __init__(self, + arch, + stem_fn, + in_channels=3, + out_indices=-1, + frozen_stages=-1, + drop_path_rate=0., + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + norm_eval=False, + init_cfg=dict(type='Kaiming', layer='Conv2d')): + super().__init__(init_cfg=init_cfg) + self.arch = self.expand_arch(arch) + self.num_stages = len(self.arch['in_channels']) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + if frozen_stages not in range(-1, self.num_stages): + raise ValueError('frozen_stages must be in range(-1, ' + f'{self.num_stages}). But received ' + f'{frozen_stages}') + self.frozen_stages = frozen_stages + + self.stem = stem_fn(in_channels) + + stages = [] + depths = self.arch['num_blocks'] + dpr = torch.linspace(0, drop_path_rate, sum(depths)).split(depths) + + for i in range(self.num_stages): + stage_cfg = {k: v[i] for k, v in self.arch.items()} + csp_stage = CSPStage( + **stage_cfg, + block_dpr=dpr[i].tolist(), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + init_cfg=init_cfg) + stages.append(csp_stage) + self.stages = Sequential(*stages) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = len(self.stages) + index + assert 0 <= out_indices[i] <= len(self.stages), \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + @staticmethod + def expand_arch(arch): + num_stages = len(arch['in_channels']) + + def to_tuple(x, name=''): + if isinstance(x, (list, tuple)): + assert len(x) == num_stages, \ + f'The length of {name} ({len(x)}) does not ' \ + f'equals to the number of stages ({num_stages})' + return tuple(x) + else: + return (x, ) * num_stages + + full_arch = {k: to_tuple(v, k) for k, v in arch.items()} + if 'block_args' not in full_arch: + full_arch['block_args'] = to_tuple({}) + return full_arch + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(CSPNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def forward(self, x): + outs = [] + + x = self.stem(x) + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + +@MODELS.register_module() +class CSPDarkNet(CSPNet): + """CSP-Darknet backbone used in YOLOv4. + + Args: + depth (int): Depth of CSP-Darknet. Default: 53. + in_channels (int): Number of input image channels. Default: 3. + out_indices (Sequence[int]): Output from which stages. + Default: (3, ). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from mmpretrain.models import CSPDarkNet + >>> import torch + >>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 416, 416) + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... + (1, 64, 208, 208) + (1, 128, 104, 104) + (1, 256, 52, 52) + (1, 512, 26, 26) + (1, 1024, 13, 13) + """ + arch_settings = { + 53: + dict( + block_fn=DarknetBottleneck, + in_channels=(32, 64, 128, 256, 512), + out_channels=(64, 128, 256, 512, 1024), + num_blocks=(1, 2, 8, 8, 4), + expand_ratio=(2, 1, 1, 1, 1), + bottle_ratio=(2, 1, 1, 1, 1), + has_downsampler=True, + down_growth=True, + ), + } + + def __init__(self, + depth, + in_channels=3, + out_indices=(4, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + norm_eval=False, + init_cfg=dict( + type='Kaiming', + layer='Conv2d', + a=math.sqrt(5), + distribution='uniform', + mode='fan_in', + nonlinearity='leaky_relu')): + + assert depth in self.arch_settings, 'depth must be one of ' \ + f'{list(self.arch_settings.keys())}, but get {depth}.' + + super().__init__( + arch=self.arch_settings[depth], + stem_fn=self._make_stem_layer, + in_channels=in_channels, + out_indices=out_indices, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + norm_eval=norm_eval, + init_cfg=init_cfg) + + def _make_stem_layer(self, in_channels): + """using a stride=1 conv as the stem in CSPDarknet.""" + # `stem_channels` equals to the `in_channels` in the first stage. + stem_channels = self.arch['in_channels'][0] + stem = ConvModule( + in_channels=in_channels, + out_channels=stem_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + return stem + + +@MODELS.register_module() +class CSPResNet(CSPNet): + """CSP-ResNet backbone. + + Args: + depth (int): Depth of CSP-ResNet. Default: 50. + out_indices (Sequence[int]): Output from which stages. + Default: (4, ). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Example: + >>> from mmpretrain.models import CSPResNet + >>> import torch + >>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 416, 416) + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... + (1, 128, 104, 104) + (1, 256, 52, 52) + (1, 512, 26, 26) + (1, 1024, 13, 13) + """ + arch_settings = { + 50: + dict( + block_fn=ResNetBottleneck, + in_channels=(64, 128, 256, 512), + out_channels=(128, 256, 512, 1024), + num_blocks=(3, 3, 5, 2), + expand_ratio=4, + bottle_ratio=2, + has_downsampler=(False, True, True, True), + down_growth=False), + } + + def __init__(self, + depth, + in_channels=3, + out_indices=(3, ), + frozen_stages=-1, + deep_stem=False, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + norm_eval=False, + init_cfg=dict(type='Kaiming', layer='Conv2d')): + assert depth in self.arch_settings, 'depth must be one of ' \ + f'{list(self.arch_settings.keys())}, but get {depth}.' + self.deep_stem = deep_stem + + super().__init__( + arch=self.arch_settings[depth], + stem_fn=self._make_stem_layer, + in_channels=in_channels, + out_indices=out_indices, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + norm_eval=norm_eval, + init_cfg=init_cfg) + + def _make_stem_layer(self, in_channels): + # `stem_channels` equals to the `in_channels` in the first stage. + stem_channels = self.arch['in_channels'][0] + if self.deep_stem: + stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + else: + stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + return stem + + +@MODELS.register_module() +class CSPResNeXt(CSPResNet): + """CSP-ResNeXt backbone. + + Args: + depth (int): Depth of CSP-ResNeXt. Default: 50. + out_indices (Sequence[int]): Output from which stages. + Default: (4, ). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Example: + >>> from mmpretrain.models import CSPResNeXt + >>> import torch + >>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + arch_settings = { + 50: + dict( + block_fn=ResNeXtBottleneck, + in_channels=(64, 256, 512, 1024), + out_channels=(256, 512, 1024, 2048), + num_blocks=(3, 3, 5, 2), + expand_ratio=(4, 2, 2, 2), + bottle_ratio=4, + has_downsampler=(False, True, True, True), + down_growth=False, + # the base_channels is changed from 64 to 32 in CSPNet + block_args=dict(base_channels=32), + ), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/mmpretrain/models/backbones/davit.py b/mmpretrain/models/backbones/davit.py new file mode 100644 index 0000000000000000000000000000000000000000..cf25e2ed7137fb403e38801b50b355c4306331d6 --- /dev/null +++ b/mmpretrain/models/backbones/davit.py @@ -0,0 +1,834 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.cnn.bricks import Conv2d +from mmcv.cnn.bricks.transformer import FFN, AdaptivePadding, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.utils import to_2tuple +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import ShiftWindowMSA + + +class DaViTWindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module for DaViT. + + The differences between DaViTWindowMSA & WindowMSA: + 1. Without relative position bias. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop (float, optional): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float, optional): Dropout ratio of output. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ConvPosEnc(BaseModule): + """DaViT conv pos encode block. + + Args: + embed_dims (int): Number of input channels. + kernel_size (int): The kernel size of the first convolution. + Defaults to 3. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, embed_dims, kernel_size=3, init_cfg=None): + super(ConvPosEnc, self).__init__(init_cfg) + self.proj = Conv2d( + embed_dims, + embed_dims, + kernel_size, + stride=1, + padding=kernel_size // 2, + groups=embed_dims) + + def forward(self, x, size: Tuple[int, int]): + B, N, C = x.shape + H, W = size + assert N == H * W + + feat = x.transpose(1, 2).view(B, C, H, W) + feat = self.proj(feat) + feat = feat.flatten(2).transpose(1, 2) + x = x + feat + return x + + +class DaViTDownSample(BaseModule): + """DaViT down sampole block. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + conv_type (str): The type of convolution + to generate patch embedding. Default: "Conv2d". + kernel_size (int): The kernel size of the first convolution. + Defaults to 2. + stride (int): The stride of the second convluation module. + Defaults to 2. + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Defaults to "corner". + dilation (int): Dilation of the convolution layers. Defaults to 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_type='Conv2d', + kernel_size=2, + stride=2, + padding='same', + dilation=1, + bias=True, + norm_cfg=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.out_channels = out_channels + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adaptive_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adaptive_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, in_channels)[1] + else: + self.norm = None + + def forward(self, x, input_size): + if self.adaptive_padding: + x = self.adaptive_padding(x) + H, W = input_size + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + + x = self.norm(x) + x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() + + x = self.projection(x) + output_size = (x.size(2), x.size(3)) + x = x.flatten(2).transpose(1, 2) + return x, output_size + + +class ChannelAttention(BaseModule): + """DaViT channel attention. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, embed_dims, num_heads=8, qkv_bias=False, init_cfg=None): + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + self.head_dims = embed_dims // num_heads + self.scale = self.head_dims**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dims, embed_dims) + + def forward(self, x): + B, N, _ = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + k = k * self.scale + attention = k.transpose(-1, -2) @ v + attention = attention.softmax(dim=-1) + + x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + x = self.proj(x) + return x + + +class ChannelBlock(BaseModule): + """DaViT channel attention block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + ffn_ratio=4., + qkv_bias=False, + drop_path=0., + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg) + self.with_cp = with_cp + + self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ChannelAttention( + embed_dims, num_heads=num_heads, qkv_bias=qkv_bias) + self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.cpe1(x, hw_shape) + identity = x + x = self.norm1(x) + x = self.attn(x) + x = x + identity + + x = self.cpe2(x, hw_shape) + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SpatialBlock(BaseModule): + """DaViT spatial attention block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(SpatialBlock, self).__init__(init_cfg) + self.with_cp = with_cp + + self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'qkv_bias': qkv_bias, + 'pad_small_map': pad_small_map, + 'window_msa': DaViTWindowMSA, + **attn_cfgs + } + self.attn = ShiftWindowMSA(**_attn_cfgs) + self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.cpe1(x, hw_shape) + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + x = x + identity + + x = self.cpe2(x, hw_shape) + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class DaViTBlock(BaseModule): + """DaViT block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(DaViTBlock, self).__init__(init_cfg) + self.spatial_block = SpatialBlock( + embed_dims, + num_heads, + window_size=window_size, + ffn_ratio=ffn_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path, + pad_small_map=pad_small_map, + attn_cfgs=attn_cfgs, + ffn_cfgs=ffn_cfgs, + norm_cfg=norm_cfg, + with_cp=with_cp) + self.channel_block = ChannelBlock( + embed_dims, + num_heads, + ffn_ratio=ffn_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path, + ffn_cfgs=ffn_cfgs, + norm_cfg=norm_cfg, + with_cp=False) + + def forward(self, x, hw_shape): + x = self.spatial_block(x, hw_shape) + x = self.channel_block(x, hw_shape) + + return x + + +class DaViTBlockSequence(BaseModule): + """Module with successive DaViT blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive DaViT blocks. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + self.embed_dims = embed_dims + self.blocks = ModuleList() + for i in range(depth): + _block_cfg = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'window_size': window_size, + 'ffn_ratio': ffn_ratio, + 'qkv_bias': qkv_bias, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **block_cfgs[i] + } + block = DaViTBlock(**_block_cfg) + self.blocks.append(block) + + if downsample: + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': 2 * embed_dims, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = DaViTDownSample(**_downsample_cfg) + else: + self.downsample = None + + def forward(self, x, in_shape, do_downsample=True): + for block in self.blocks: + x = block(x, in_shape) + + if self.downsample is not None and do_downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + return x, out_shape + + @property + def out_channels(self): + if self.downsample: + return self.downsample.out_channels + else: + return self.embed_dims + + +@MODELS.register_module() +class DaViT(BaseBackbone): + """DaViT. + + A PyTorch implement of : `DaViT: Dual Attention Vision Transformers + `_ + + Inspiration from + https://github.com/dingmyu/davit + + Args: + arch (str | dict): DaViT architecture. If use string, choose from + 'tiny', 'small', 'base' and 'large', 'huge', 'giant'. If use dict, + it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + + Defaults to 't'. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_after_downsample (bool): Whether to output the feature map of a + stage after the following downsample layer. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], { + 'embed_dims': 96, + 'depths': [1, 1, 3, 1], + 'num_heads': [3, 6, 12, 24] + }), + **dict.fromkeys(['s', 'small'], { + 'embed_dims': 96, + 'depths': [1, 1, 9, 1], + 'num_heads': [3, 6, 12, 24] + }), + **dict.fromkeys(['b', 'base'], { + 'embed_dims': 128, + 'depths': [1, 1, 9, 1], + 'num_heads': [4, 8, 16, 32] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 192, + 'depths': [1, 1, 9, 1], + 'num_heads': [6, 12, 24, 48] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 256, + 'depths': [1, 1, 9, 1], + 'num_heads': [8, 16, 32, 64] + }), + **dict.fromkeys( + ['g', 'giant'], { + 'embed_dims': 384, + 'depths': [1, 1, 12, 3], + 'num_heads': [12, 24, 48, 96] + }), + } + + def __init__(self, + arch='t', + patch_size=4, + in_channels=3, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + out_after_downsample=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + frozen_stages=-1, + norm_eval=False, + out_indices=(3, ), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.out_after_downsample = out_after_downsample + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + # stochastic depth decay rule + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + _patch_cfg = dict( + in_channels=in_channels, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=7, + stride=patch_size, + padding='same', + norm_cfg=dict(type='LN'), + ) + self.patch_embed = PatchEmbed(**_patch_cfg) + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i < self.num_layers - 1 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': window_size, + 'ffn_ratio': ffn_ratio, + 'qkv_bias': qkv_bias, + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **stage_cfg + } + + stage = DaViTBlockSequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + self.num_features = embed_dims[:-1] + + # add a norm layer for each output + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, + self.num_features[i])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage( + x, hw_shape, do_downsample=self.out_after_downsample) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + if stage.downsample is not None and not self.out_after_downsample: + x, hw_shape = stage.downsample(x, hw_shape) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/deit.py b/mmpretrain/models/backbones/deit.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae340829bece31536d0c0ac119ffe635bce82e0 --- /dev/null +++ b/mmpretrain/models/backbones/deit.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from .vision_transformer import VisionTransformer + + +@MODELS.register_module() +class DistilledVisionTransformer(VisionTransformer): + """Distilled Vision Transformer. + + A PyTorch implement of : `Training data-efficient image transformers & + distillation through attention `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'deit-base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: A tuple with the class token and the + distillation token. The shapes of both tensor are (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + num_extra_tokens = 2 # class token and distillation token + + def __init__(self, arch='deit-base', *args, **kwargs): + super(DistilledVisionTransformer, self).__init__( + arch=arch, + with_cls_token=True, + *args, + **kwargs, + ) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + x = x + self.resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'cls_token': + return x[:, 0], x[:, 1] + + return super()._format_output(x, hw) + + def init_weights(self): + super(DistilledVisionTransformer, self).init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + trunc_normal_(self.dist_token, std=0.02) diff --git a/mmpretrain/models/backbones/deit3.py b/mmpretrain/models/backbones/deit3.py new file mode 100644 index 0000000000000000000000000000000000000000..acedabe42d66a8073f34b1b0ae87501522fcc1b5 --- /dev/null +++ b/mmpretrain/models/backbones/deit3.py @@ -0,0 +1,454 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import numpy as np +import torch +from mmcv.cnn import Linear, build_activation_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.utils import deprecated_api_warning +from torch import nn + +from mmpretrain.registry import MODELS +from ..utils import (LayerScale, MultiheadAttention, build_norm_layer, + resize_pos_embed, to_2tuple) +from .vision_transformer import VisionTransformer + + +class DeiT3FFN(BaseModule): + """FFN for DeiT3. + + The differences between DeiT3FFN & FFN: + 1. Use LayerScale. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + use_layer_scale (bool): Whether to use layer_scale in + DeiT3FFN. Defaults to True. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning( + { + 'dropout': 'ffn_drop', + 'add_residual': 'add_identity' + }, + cls_name='FFN') + def __init__(self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0., + dropout_layer=None, + add_identity=True, + use_layer_scale=True, + init_cfg=None, + **kwargs): + super().__init__(init_cfg) + assert num_fcs >= 2, 'num_fcs should be no less ' \ + f'than 2. got {num_fcs}.' + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + Sequential( + Linear(in_channels, feedforward_channels), self.activate, + nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + if use_layer_scale: + self.gamma2 = LayerScale(embed_dims) + else: + self.gamma2 = nn.Identity() + + @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN') + def forward(self, x, identity=None): + """Forward function for `FFN`. + + The function would add x to the output tensor if residue is None. + """ + out = self.layers(x) + out = self.gamma2(out) + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class DeiT3TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in DeiT3. + + The differences between DeiT3TransformerEncoderLayer & + TransformerEncoderLayer: + 1. Use LayerScale. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + use_layer_scale (bool): Whether to use layer_scale in + DeiT3TransformerEncoderLayer. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + use_layer_scale=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(DeiT3TransformerEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + use_layer_scale=use_layer_scale) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + self.ffn = DeiT3FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + use_layer_scale=use_layer_scale) + + def init_weights(self): + super(DeiT3TransformerEncoderLayer, self).init_weights() + for m in self.ffn.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln1(x), identity=x) + return x + + +@MODELS.register_module() +class DeiT3(VisionTransformer): + """DeiT3 backbone. + + A PyTorch implement of : `DeiT III: Revenge of the ViT + `_ + + The differences between DeiT3 & VisionTransformer: + + 1. Use LayerScale. + 2. Concat cls token after adding pos_embed. + + Args: + arch (str | dict): DeiT3 architecture. If use string, + choose from 'small', 'base', 'medium', 'large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + use_layer_scale (bool): Whether to use layer_scale in DeiT3. + Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 1536, + }), + **dict.fromkeys( + ['m', 'medium'], { + 'embed_dims': 512, + 'num_layers': 12, + 'num_heads': 8, + 'feedforward_channels': 2048, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120 + }), + } + num_extra_tokens = 1 # class token + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='cls_token', + with_cls_token=True, + use_layer_scale=True, + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + use_layer_scale=use_layer_scale) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(DeiT3TransformerEncoderLayer(**_layer_cfg)) + + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=0) + x = self.drop_after_pos(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1]))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed( + state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + num_extra_tokens=0, # The cls token adding is after pos_embed + ) diff --git a/mmpretrain/models/backbones/densenet.py b/mmpretrain/models/backbones/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f05302f9b84cd38c7c03701fc21ffd109c1620 --- /dev/null +++ b/mmpretrain/models/backbones/densenet.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import build_activation_layer, build_norm_layer +from torch.jit.annotations import List + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class DenseLayer(BaseBackbone): + """DenseBlock layers.""" + + def __init__(self, + in_channels, + growth_rate, + bn_size, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_rate=0., + memory_efficient=False): + super(DenseLayer, self).__init__() + + self.norm1 = build_norm_layer(norm_cfg, in_channels)[1] + self.conv1 = nn.Conv2d( + in_channels, + bn_size * growth_rate, + kernel_size=1, + stride=1, + bias=False) + self.act = build_activation_layer(act_cfg) + self.norm2 = build_norm_layer(norm_cfg, bn_size * growth_rate)[1] + self.conv2 = nn.Conv2d( + bn_size * growth_rate, + growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bottleneck_fn(self, xs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(xs, 1) + bottleneck_output = self.conv1( + self.act(self.norm1(concated_features))) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, x): + # type: (List[torch.Tensor]) -> bool + for tensor in x: + if tensor.requires_grad: + return True + return False + + # This decorator indicates to the compiler that a function or method + # should be ignored and replaced with the raising of an exception. + # Here this function is incompatible with torchscript. + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, x): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*xs): + return self.bottleneck_fn(xs) + + # Here use torch.utils.checkpoint to rerun a forward-pass during + # backward in bottleneck to save memories. + return cp.checkpoint(closure, *x) + + def forward(self, x): # noqa: F811 + # type: (List[torch.Tensor]) -> torch.Tensor + # assert input features is a list of Tensor + assert isinstance(x, list) + + if self.memory_efficient and self.any_requires_grad(x): + if torch.jit.is_scripting(): + raise Exception('Memory Efficient not supported in JIT') + bottleneck_output = self.call_checkpoint_bottleneck(x) + else: + bottleneck_output = self.bottleneck_fn(x) + + new_features = self.conv2(self.act(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = F.dropout( + new_features, p=self.drop_rate, training=self.training) + return new_features + + +class DenseBlock(nn.Module): + """DenseNet Blocks.""" + + def __init__(self, + num_layers, + in_channels, + bn_size, + growth_rate, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_rate=0., + memory_efficient=False): + super(DenseBlock, self).__init__() + self.block = nn.ModuleList([ + DenseLayer( + in_channels + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + memory_efficient=memory_efficient) for i in range(num_layers) + ]) + + def forward(self, init_features): + features = [init_features] + for layer in self.block: + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + + +class DenseTransition(nn.Sequential): + """DenseNet Transition Layers.""" + + def __init__(self, + in_channels, + out_channels, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU')): + super(DenseTransition, self).__init__() + self.add_module('norm', build_norm_layer(norm_cfg, in_channels)[1]) + self.add_module('act', build_activation_layer(act_cfg)) + self.add_module( + 'conv', + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, + bias=False)) + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +@MODELS.register_module() +class DenseNet(BaseBackbone): + """DenseNet. + + A PyTorch implementation of : `Densely Connected Convolutional Networks + `_ + + Modified from the `official repo + `_ + and `pytorch + `_. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``DenseNet.arch_settings``. And if dict, it + should include the following two keys: + + - growth_rate (int): Each layer of DenseBlock produce `k` feature + maps. Here refers `k` as the growth rate of the network. + - depths (list[int]): Number of repeated layers in each DenseBlock. + - init_channels (int): The output channels of stem layers. + + Defaults to '121'. + in_channels (int): Number of input image channels. Defaults to 3. + bn_size (int): Refers to channel expansion parameter of 1x1 + convolution layer. Defaults to 4. + drop_rate (float): Drop rate of Dropout Layer. Defaults to 0. + compression_factor (float): The reduction rate of transition layers. + Defaults to 0.5. + memory_efficient (bool): If True, uses checkpointing. Much more memory + efficient, but slower. Defaults to False. + See `"paper" `_. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation after each convolution. + Defaults to ``dict(type='ReLU')``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict. + """ + arch_settings = { + '121': { + 'growth_rate': 32, + 'depths': [6, 12, 24, 16], + 'init_channels': 64, + }, + '169': { + 'growth_rate': 32, + 'depths': [6, 12, 32, 32], + 'init_channels': 64, + }, + '201': { + 'growth_rate': 32, + 'depths': [6, 12, 48, 32], + 'init_channels': 64, + }, + '161': { + 'growth_rate': 48, + 'depths': [6, 12, 36, 24], + 'init_channels': 96, + }, + } + + def __init__(self, + arch='121', + in_channels=3, + bn_size=4, + drop_rate=0, + compression_factor=0.5, + memory_efficient=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + out_indices=-1, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = {'growth_rate', 'depths', 'init_channels'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + + self.growth_rate = arch['growth_rate'] + self.depths = arch['depths'] + self.init_channels = arch['init_channels'] + self.act = build_activation_layer(act_cfg) + + self.num_stages = len(self.depths) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # Set stem layers + self.stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.init_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False), + build_norm_layer(norm_cfg, self.init_channels)[1], self.act, + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + # Repetitions of DenseNet Blocks + self.stages = nn.ModuleList() + self.transitions = nn.ModuleList() + + channels = self.init_channels + for i in range(self.num_stages): + depth = self.depths[i] + + stage = DenseBlock( + num_layers=depth, + in_channels=channels, + bn_size=bn_size, + growth_rate=self.growth_rate, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + memory_efficient=memory_efficient) + self.stages.append(stage) + channels += depth * self.growth_rate + + if i != self.num_stages - 1: + transition = DenseTransition( + in_channels=channels, + out_channels=math.floor(channels * compression_factor), + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + channels = math.floor(channels * compression_factor) + else: + # Final layers after dense block is just bn with act. + # Unlike the paper, the original repo also put this in + # transition layer, whereas torchvision take this out. + # We reckon this as transition layer here. + transition = nn.Sequential( + build_norm_layer(norm_cfg, channels)[1], + self.act, + ) + self.transitions.append(transition) + + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i in range(self.num_stages): + x = self.stages[i](x) + x = self.transitions[i](x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.transitions[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(DenseNet, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/edgenext.py b/mmpretrain/models/backbones/edgenext.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4e768e7561eb49da3603f4394faaebed7c9251 --- /dev/null +++ b/mmpretrain/models/backbones/edgenext.py @@ -0,0 +1,398 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import (ChannelMultiheadAttention, PositionEncodingFourier, + build_norm_layer) +from .base_backbone import BaseBackbone +from .convnext import ConvNeXtBlock + + +class SDTAEncoder(BaseModule): + """A PyTorch implementation of split depth-wise transpose attention (SDTA) + encoder. + + Inspiration from + https://github.com/mmaaz60/EdgeNeXt + Args: + in_channel (int): Number of input channels. + drop_path_rate (float): Stochastic depth dropout rate. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + mlp_ratio (int): Number of channels ratio in the MLP. + Defaults to 4. + use_pos_emb (bool): Whether to use position encoding. + Defaults to True. + num_heads (int): Number of heads in the multihead attention. + Defaults to 8. + qkv_bias (bool): Whether to use bias in the multihead attention. + Defaults to True. + attn_drop (float): Dropout rate of the attention. + Defaults to 0. + proj_drop (float): Dropout rate of the projection. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + norm_cfg (dict): Dictionary to construct normalization layer. + Defaults to ``dict(type='LN')``. + act_cfg (dict): Dictionary to construct activation layer. + Defaults to ``dict(type='GELU')``. + scales (int): Number of scales. Default to 1. + """ + + def __init__(self, + in_channel, + drop_path_rate=0., + layer_scale_init_value=1e-6, + mlp_ratio=4, + use_pos_emb=True, + num_heads=8, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + scales=1, + init_cfg=None): + super(SDTAEncoder, self).__init__(init_cfg=init_cfg) + conv_channels = max( + int(math.ceil(in_channel / scales)), + int(math.floor(in_channel // scales))) + self.conv_channels = conv_channels + self.num_convs = scales if scales == 1 else scales - 1 + + self.conv_modules = ModuleList() + for i in range(self.num_convs): + self.conv_modules.append( + nn.Conv2d( + conv_channels, + conv_channels, + kernel_size=3, + padding=1, + groups=conv_channels)) + + self.pos_embed = PositionEncodingFourier( + embed_dims=in_channel) if use_pos_emb else None + + self.norm_csa = build_norm_layer(norm_cfg, in_channel) + self.gamma_csa = nn.Parameter( + layer_scale_init_value * torch.ones(in_channel), + requires_grad=True) if layer_scale_init_value > 0 else None + self.csa = ChannelMultiheadAttention( + embed_dims=in_channel, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop) + + self.norm = build_norm_layer(norm_cfg, in_channel) + self.pointwise_conv1 = nn.Linear(in_channel, mlp_ratio * in_channel) + self.act = MODELS.build(act_cfg) + self.pointwise_conv2 = nn.Linear(mlp_ratio * in_channel, in_channel) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones(in_channel), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + spx = torch.split(x, self.conv_channels, dim=1) + for i in range(self.num_convs): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.conv_modules[i](sp) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + x = torch.cat((out, spx[self.num_convs]), 1) + + # Channel Self-attention + B, C, H, W = x.shape + x = x.reshape(B, C, H * W).permute(0, 2, 1) + if self.pos_embed: + pos_encoding = self.pos_embed((B, H, W)) + pos_encoding = pos_encoding.reshape(B, -1, + x.shape[1]).permute(0, 2, 1) + x += pos_encoding + + x = x + self.drop_path(self.gamma_csa * self.csa(self.norm_csa(x))) + x = x.reshape(B, H, W, C) + + # Inverted Bottleneck + x = self.norm(x) + x = self.pointwise_conv1(x) + x = self.act(x) + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) + + x = shortcut + self.drop_path(x) + + return x + + +@MODELS.register_module() +class EdgeNeXt(BaseBackbone): + """EdgeNeXt. + + A PyTorch implementation of: `EdgeNeXt: Efficiently Amalgamated + CNN-Transformer Architecture for Mobile Vision Applications + `_ + + Inspiration from + https://github.com/mmaaz60/EdgeNeXt + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architectures in ``EdgeNeXt.arch_settings``. + And if dict, it should include the following keys: + + - channels (list[int]): The number of channels at each stage. + - depths (list[int]): The number of blocks at each stage. + - num_heads (list[int]): The number of heads at each stage. + + Defaults to 'xxsmall'. + in_channels (int): The number of input channels. + Defaults to 3. + global_blocks (list[int]): The number of global blocks. + Defaults to [0, 1, 1, 1]. + global_block_type (list[str]): The type of global blocks. + Defaults to ['None', 'SDTA', 'SDTA', 'SDTA']. + drop_path_rate (float): Stochastic depth dropout rate. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to False. + mlp_ratio (int): The number of channel ratio in MLP layers. + Defaults to 4. + conv_kernel_size (list[int]): The kernel size of convolutional layers + at each stage. Defaults to [3, 5, 7, 9]. + use_pos_embd_csa (list[bool]): Whether to use positional embedding in + Channel Self-Attention. Defaults to [False, True, False, False]. + use_pos_emebd_global (bool): Whether to use positional embedding for + whole network. Defaults to False. + d2_scales (list[int]): The number of channel groups used for SDTA at + each stage. Defaults to [2, 2, 3, 4]. + norm_cfg (dict): The config of normalization layer. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. Defaults to True. + act_cfg (dict): The config of activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict, optional): Config for initialization. + Defaults to None. + """ + arch_settings = { + 'xxsmall': { # parameters: 1.3M + 'channels': [24, 48, 88, 168], + 'depths': [2, 2, 6, 2], + 'num_heads': [4, 4, 4, 4] + }, + 'xsmall': { # parameters: 2.3M + 'channels': [32, 64, 100, 192], + 'depths': [3, 3, 9, 3], + 'num_heads': [4, 4, 4, 4] + }, + 'small': { # parameters: 5.6M + 'channels': [48, 96, 160, 304], + 'depths': [3, 3, 9, 3], + 'num_heads': [8, 8, 8, 8] + }, + 'base': { # parameters: 18.51M + 'channels': [80, 160, 288, 584], + 'depths': [3, 3, 9, 3], + 'num_heads': [8, 8, 8, 8] + }, + } + + def __init__(self, + arch='xxsmall', + in_channels=3, + global_blocks=[0, 1, 1, 1], + global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'], + drop_path_rate=0., + layer_scale_init_value=1e-6, + linear_pw_conv=True, + mlp_ratio=4, + conv_kernel_sizes=[3, 5, 7, 9], + use_pos_embd_csa=[False, True, False, False], + use_pos_embd_global=False, + d2_scales=[2, 2, 3, 4], + norm_cfg=dict(type='LN2d', eps=1e-6), + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + act_cfg=dict(type='GELU'), + init_cfg=None): + super(EdgeNeXt, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in self.arch_settings, \ + f'Arch {arch} is not in default archs ' \ + f'{set(self.arch_settings)}' + self.arch_settings = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = {'channels', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.channels = self.arch_settings['channels'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.use_pos_embd_global = use_pos_embd_global + + for g in global_block_type: + assert g in ['None', + 'SDTA'], f'Global block type {g} is not supported' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + + if self.use_pos_embd_global: + self.pos_embed = PositionEncodingFourier( + embed_dims=self.channels[0]) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + + self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d(in_channels, self.channels[0], kernel_size=4, stride=4), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + self.stages = ModuleList() + block_idx = 0 + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2, + )) + self.downsample_layers.append(downsample_layer) + + stage_blocks = [] + for j in range(depth): + if j > depth - global_blocks[i] - 1: + stage_blocks.append( + SDTAEncoder( + in_channel=channels, + drop_path_rate=dpr[block_idx + j], + mlp_ratio=mlp_ratio, + scales=d2_scales[i], + use_pos_emb=use_pos_embd_csa[i], + num_heads=self.num_heads[i], + )) + else: + dw_conv_cfg = dict( + kernel_size=conv_kernel_sizes[i], + padding=conv_kernel_sizes[i] // 2, + ) + stage_blocks.append( + ConvNeXtBlock( + in_channels=channels, + dw_conv_cfg=dw_conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + drop_path_rate=dpr[block_idx + j], + layer_scale_init_value=layer_scale_init_value, + )) + block_idx += depth + + stage_blocks = Sequential(*stage_blocks) + self.stages.append(stage_blocks) + + if i in self.out_indices: + out_norm_cfg = dict(type='LN') if self.gap_before_final_norm \ + else norm_cfg + norm_layer = build_norm_layer(out_norm_cfg, channels) + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self) -> None: + # TODO: need to be implemented in the future + return super().init_weights() + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if self.pos_embed and i == 0: + B, _, H, W = x.shape + x += self.pos_embed((B, H, W)) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap.flatten(1))) + else: + # The output of LayerNorm2d may be discontiguous, which + # may cause some problem in the downstream tasks + outs.append(norm_layer(x).contiguous()) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.downsample_layers[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(EdgeNeXt, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/efficientformer.py b/mmpretrain/models/backbones/efficientformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c2525c8faaa745ff5404e91004421f2360dd1c41 --- /dev/null +++ b/mmpretrain/models/backbones/efficientformer.py @@ -0,0 +1,606 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +from typing import Optional, Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer, + build_norm_layer) +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import LayerScale +from .base_backbone import BaseBackbone +from .poolformer import Pooling + + +class AttentionWithBias(BaseModule): + """Multi-head Attention Module with attention_bias. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. Defaults to 8. + key_dim (int): The dimension of q, k. Defaults to 32. + attn_ratio (float): The dimension of v equals to + ``key_dim * attn_ratio``. Defaults to 4. + resolution (int): The height and width of attention_bias. + Defaults to 7. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads=8, + key_dim=32, + attn_ratio=4., + resolution=7, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.attn_ratio = attn_ratio + self.key_dim = key_dim + self.nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + h = self.dh + self.nh_kd * 2 + self.qkv = nn.Linear(embed_dims, h) + self.proj = nn.Linear(self.dh, embed_dims) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + """change the mode of model.""" + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + """forward function. + + Args: + x (tensor): input features with shape of (B, N, C) + """ + B, N, _ = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1) + + attn = ((q @ k.transpose(-2, -1)) * self.scale + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class Flat(nn.Module): + """Flat the input from (B, C, H, W) to (B, H*W, C).""" + + def __init__(self, ): + super().__init__() + + def forward(self, x: torch.Tensor): + x = x.flatten(2).transpose(1, 2) + return x + + +class LinearMlp(BaseModule): + """Mlp implemented with linear. + + The shape of input and output tensor are (B, N, C). + + Args: + in_features (int): Dimension of input features. + hidden_features (int): Dimension of hidden features. + out_features (int): Dimension of output features. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0.0. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = build_activation_layer(act_cfg) + self.drop1 = nn.Dropout(drop) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop2 = nn.Dropout(drop) + + def forward(self, x): + """ + Args: + x (torch.Tensor): input tensor with shape (B, N, C). + + Returns: + torch.Tensor: output tensor with shape (B, N, C). + """ + x = self.drop1(self.act(self.fc1(x))) + x = self.drop2(self.fc2(x)) + return x + + +class ConvMlp(BaseModule): + """Mlp implemented with 1*1 convolutions. + + Args: + in_features (int): Dimension of input features. + hidden_features (int): Dimension of hidden features. + out_features (int): Dimension of output features. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0.0. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1] + self.norm2 = build_norm_layer(norm_cfg, out_features)[1] + + self.drop = nn.Dropout(drop) + + def forward(self, x): + """ + Args: + x (torch.Tensor): input tensor with shape (B, C, H, W). + + Returns: + torch.Tensor: output tensor with shape (B, C, H, W). + """ + + x = self.act(self.norm1(self.fc1(x))) + x = self.drop(x) + x = self.norm2(self.fc2(x)) + x = self.drop(x) + return x + + +class Meta3D(BaseModule): + """Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape + (B, N, C).""" + + def __init__(self, + dim, + mlp_ratio=4., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + use_layer_scale=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = AttentionWithBias(dim) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = LinearMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + if use_layer_scale: + self.ls1 = LayerScale(dim) + self.ls2 = LayerScale(dim) + else: + self.ls1, self.ls2 = nn.Identity(), nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x)))) + x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class Meta4D(BaseModule): + """Meta Former block using 4 dimensions inputs, ``torch.Tensor`` with shape + (B, C, H, W).""" + + def __init__(self, + dim, + pool_size=3, + mlp_ratio=4., + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + use_layer_scale=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.token_mixer = Pooling(pool_size=pool_size) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ConvMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + if use_layer_scale: + self.ls1 = LayerScale(dim, data_format='channels_first') + self.ls2 = LayerScale(dim, data_format='channels_first') + else: + self.ls1, self.ls2 = nn.Identity(), nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.ls1(self.token_mixer(x))) + x = x + self.drop_path(self.ls2(self.mlp(x))) + return x + + +def basic_blocks(in_channels, + out_channels, + index, + layers, + pool_size=3, + mlp_ratio=4., + act_cfg=dict(type='GELU'), + drop_rate=.0, + drop_path_rate=0., + use_layer_scale=True, + vit_num=1, + has_downsamper=False): + """generate EfficientFormer blocks for a stage.""" + blocks = [] + if has_downsamper: + blocks.append( + ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=True, + norm_cfg=dict(type='BN'), + act_cfg=None)) + if index == 3 and vit_num == layers[index]: + blocks.append(Flat()) + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / ( + sum(layers) - 1) + if index == 3 and layers[index] - block_idx <= vit_num: + blocks.append( + Meta3D( + out_channels, + mlp_ratio=mlp_ratio, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale, + )) + else: + blocks.append( + Meta4D( + out_channels, + pool_size=pool_size, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale)) + if index == 3 and layers[index] - block_idx - 1 == vit_num: + blocks.append(Flat()) + blocks = nn.Sequential(*blocks) + return blocks + + +@MODELS.register_module() +class EfficientFormer(BaseBackbone): + """EfficientFormer. + + A PyTorch implementation of EfficientFormer introduced by: + `EfficientFormer: Vision Transformers at MobileNet Speed `_ + + Modified from the `official repo + `. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``EfficientFormer.arch_settings``. And if dict, + it should include the following 4 keys: + + - layers (list[int]): Number of blocks at each stage. + - embed_dims (list[int]): The number of channels at each stage. + - downsamples (list[int]): Has downsample or not in the four stages. + - vit_num (int): The num of vit blocks in the last stage. + + Defaults to 'l1'. + + in_channels (int): The num of input channels. Defaults to 3. + pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3. + mlp_ratios (int): The dimension ratio of multi-head attention mechanism + in ``Meta4D`` blocks. Defaults to 3. + reshape_last_feat (bool): Whether to reshape the feature map from + (B, N, C) to (B, C, H, W) in the last stage, when the ``vit-num`` + in ``arch`` is not 0. Defaults to False. Usually set to True + in downstream tasks. + out_indices (Sequence[int]): Output from which stages. + Defaults to -1. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + use_layer_scale (bool): Whether to use use_layer_scale in MetaFormer + block. Defaults to True. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + + Example: + >>> from mmpretrain.models import EfficientFormer + >>> import torch + >>> inputs = torch.rand((1, 3, 224, 224)) + >>> # build EfficientFormer backbone for classification task + >>> model = EfficientFormer(arch="l1") + >>> model.eval() + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 448, 49) + >>> # build EfficientFormer backbone for downstream task + >>> model = EfficientFormer( + >>> arch="l3", + >>> out_indices=(0, 1, 2, 3), + >>> reshape_last_feat=True) + >>> model.eval() + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 56, 56) + (1, 128, 28, 28) + (1, 320, 14, 14) + (1, 512, 7, 7) + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims: [x,x,x,x], embedding dims for the four stages + # --downsamples: [x,x,x,x], has downsample or not in the four stages + # --vit_num:(int), the num of vit blocks in the last stage + arch_settings = { + 'l1': { + 'layers': [3, 2, 6, 4], + 'embed_dims': [48, 96, 224, 448], + 'downsamples': [False, True, True, True], + 'vit_num': 1, + }, + 'l3': { + 'layers': [4, 4, 12, 6], + 'embed_dims': [64, 128, 320, 512], + 'downsamples': [False, True, True, True], + 'vit_num': 4, + }, + 'l7': { + 'layers': [6, 6, 18, 8], + 'embed_dims': [96, 192, 384, 768], + 'downsamples': [False, True, True, True], + 'vit_num': 8, + }, + } + + def __init__(self, + arch='l1', + in_channels=3, + pool_size=3, + mlp_ratios=4, + reshape_last_feat=False, + out_indices=-1, + frozen_stages=-1, + act_cfg=dict(type='GELU'), + drop_rate=0., + drop_path_rate=0., + use_layer_scale=True, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.num_extra_tokens = 0 # no cls_token, no dist_token + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + default_keys = set(self.arch_settings['l1'].keys()) + assert set(arch.keys()) == default_keys, \ + f'The arch dict must have {default_keys}, ' \ + f'but got {list(arch.keys())}.' + + self.layers = arch['layers'] + self.embed_dims = arch['embed_dims'] + self.downsamples = arch['downsamples'] + assert isinstance(self.layers, list) and isinstance( + self.embed_dims, list) and isinstance(self.downsamples, list) + assert len(self.layers) == len(self.embed_dims) == len( + self.downsamples) + + self.vit_num = arch['vit_num'] + self.reshape_last_feat = reshape_last_feat + + assert self.vit_num >= 0, "'vit_num' must be an integer " \ + 'greater than or equal to 0.' + assert self.vit_num <= self.layers[-1], ( + "'vit_num' must be an integer smaller than layer number") + + self._make_stem(in_channels, self.embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(self.layers)): + if i != 0: + in_channels = self.embed_dims[i - 1] + else: + in_channels = self.embed_dims[i] + out_channels = self.embed_dims[i] + stage = basic_blocks( + in_channels, + out_channels, + i, + self.layers, + pool_size=pool_size, + mlp_ratio=mlp_ratios, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + vit_num=self.vit_num, + use_layer_scale=use_layer_scale, + has_downsamper=self.downsamples[i]) + network.append(stage) + + self.network = ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + + self.out_indices = out_indices + for i_layer in self.out_indices: + if not self.reshape_last_feat and \ + i_layer == 3 and self.vit_num > 0: + layer = build_norm_layer( + dict(type='LN'), self.embed_dims[i_layer])[1] + else: + # use GN with 1 group as channel-first LN2D + layer = build_norm_layer( + dict(type='GN', num_groups=1), self.embed_dims[i_layer])[1] + + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + + def _make_stem(self, in_channels: int, stem_channels: int): + """make 2-ConvBNReLu stem layer.""" + self.patch_embed = Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=2, + padding=1, + bias=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + inplace=True)) + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + if idx == len(self.network) - 1: + N, _, H, W = x.shape + if self.downsamples[idx]: + H, W = H // 2, W // 2 + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + + if idx == len(self.network) - 1 and x.dim() == 3: + # when ``vit-num`` > 0 and in the last stage, + # if `self.reshape_last_feat`` is True, reshape the + # features to `BCHW` format before the final normalization. + # if `self.reshape_last_feat`` is False, do + # normalization directly and permute the features to `BCN`. + if self.reshape_last_feat: + x = x.permute((0, 2, 1)).reshape(N, -1, H, W) + x_out = norm_layer(x) + else: + x_out = norm_layer(x).permute((0, 2, 1)) + else: + x_out = norm_layer(x) + + outs.append(x_out.contiguous()) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.patch_embed(x) + # through stages + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + # Include both block and downsample layer. + module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientFormer, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/efficientnet.py b/mmpretrain/models/backbones/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec7ee81186610f7adb8af92325471d794509ddc --- /dev/null +++ b/mmpretrain/models/backbones/efficientnet.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import ConvModule, DropPath +from mmengine.model import BaseModule, Sequential + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.models.utils import InvertedResidual, SELayer, make_divisible +from mmpretrain.registry import MODELS + + +class EdgeResidual(BaseModule): + """Edge Residual Block. + + Args: + in_channels (int): The input channels of this module. + out_channels (int): The output channels of this module. + mid_channels (int): The input channels of the second convolution. + kernel_size (int): The kernel size of the first convolution. + Defaults to 3. + stride (int): The stride of the first convolution. Defaults to 1. + se_cfg (dict, optional): Config dict for se layer. Defaults to None, + which means no se layer. + with_residual (bool): Use residual connection. Defaults to True. + conv_cfg (dict, optional): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict | list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_residual=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None): + super(EdgeResidual, self).__init__(init_cfg=init_cfg) + assert stride in [1, 2] + self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_residual = ( + stride == 1 and in_channels == out_channels and with_residual) + + if self.with_se: + assert isinstance(se_cfg, dict) + + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.conv2 = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + out = self.conv1(out) + + if self.with_se: + out = self.se(out) + + out = self.conv2(out) + + if self.with_residual: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +def model_scaling(layer_setting, arch_setting): + """Scaling operation to the layer's parameters according to the + arch_setting.""" + # scale width + new_layer_setting = copy.deepcopy(layer_setting) + for layer_cfg in new_layer_setting: + for block_cfg in layer_cfg: + block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8) + + # scale depth + split_layer_setting = [new_layer_setting[0]] + for layer_cfg in new_layer_setting[1:-1]: + tmp_index = [0] + for i in range(len(layer_cfg) - 1): + if layer_cfg[i + 1][1] != layer_cfg[i][1]: + tmp_index.append(i + 1) + tmp_index.append(len(layer_cfg)) + for i in range(len(tmp_index) - 1): + split_layer_setting.append(layer_cfg[tmp_index[i]:tmp_index[i + + 1]]) + split_layer_setting.append(new_layer_setting[-1]) + + num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]] + new_layers = [ + int(math.ceil(arch_setting[1] * num)) for num in num_of_layers + ] + + merge_layer_setting = [split_layer_setting[0]] + for i, layer_cfg in enumerate(split_layer_setting[1:-1]): + if new_layers[i] <= num_of_layers[i]: + tmp_layer_cfg = layer_cfg[:new_layers[i]] + else: + tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * ( + new_layers[i] - num_of_layers[i]) + if tmp_layer_cfg[0][3] == 1 and i != 0: + merge_layer_setting[-1] += tmp_layer_cfg.copy() + else: + merge_layer_setting.append(tmp_layer_cfg.copy()) + merge_layer_setting.append(split_layer_setting[-1]) + + return merge_layer_setting + + +@MODELS.register_module() +class EfficientNet(BaseBackbone): + """EfficientNet backbone. + + Args: + arch (str): Architecture of efficientnet. Defaults to b0. + out_indices (Sequence[int]): Output from which stages. + Defaults to (6, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Defaults to dict(type='Swish'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + """ + + # Parameters to build layers. + # 'b' represents the architecture of normal EfficientNet family includes + # 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8'. + # 'e' represents the architecture of EfficientNet-EdgeTPU including 'es', + # 'em', 'el'. + # 6 parameters are needed to construct a layer, From left to right: + # - kernel_size: The kernel size of the block + # - out_channel: The number of out_channels of the block + # - se_ratio: The sequeeze ratio of SELayer. + # - stride: The stride of the block + # - expand_ratio: The expand_ratio of the mid_channels + # - block_type: -1: Not a block, 0: InvertedResidual, 1: EdgeResidual + layer_settings = { + 'b': [[[3, 32, 0, 2, 0, -1]], + [[3, 16, 4, 1, 1, 0]], + [[3, 24, 4, 2, 6, 0], + [3, 24, 4, 1, 6, 0]], + [[5, 40, 4, 2, 6, 0], + [5, 40, 4, 1, 6, 0]], + [[3, 80, 4, 2, 6, 0], + [3, 80, 4, 1, 6, 0], + [3, 80, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0]], + [[5, 192, 4, 2, 6, 0], + [5, 192, 4, 1, 6, 0], + [5, 192, 4, 1, 6, 0], + [5, 192, 4, 1, 6, 0], + [3, 320, 4, 1, 6, 0]], + [[1, 1280, 0, 1, 0, -1]] + ], + 'e': [[[3, 32, 0, 2, 0, -1]], + [[3, 24, 0, 1, 3, 1]], + [[3, 32, 0, 2, 8, 1], + [3, 32, 0, 1, 8, 1]], + [[3, 48, 0, 2, 8, 1], + [3, 48, 0, 1, 8, 1], + [3, 48, 0, 1, 8, 1], + [3, 48, 0, 1, 8, 1]], + [[5, 96, 0, 2, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0]], + [[5, 192, 0, 2, 8, 0], + [5, 192, 0, 1, 8, 0]], + [[1, 1280, 0, 1, 0, -1]] + ] + } # yapf: disable + + # Parameters to build different kinds of architecture. + # From left to right: scaling factor for width, scaling factor for depth, + # resolution. + arch_settings = { + 'b0': (1.0, 1.0, 224), + 'b1': (1.0, 1.1, 240), + 'b2': (1.1, 1.2, 260), + 'b3': (1.2, 1.4, 300), + 'b4': (1.4, 1.8, 380), + 'b5': (1.6, 2.2, 456), + 'b6': (1.8, 2.6, 528), + 'b7': (2.0, 3.1, 600), + 'b8': (2.2, 3.6, 672), + 'l2': (4.3, 5.3, 800), + 'es': (1.0, 1.0, 224), + 'em': (1.0, 1.1, 240), + 'el': (1.2, 1.4, 300) + } + + def __init__(self, + arch='b0', + drop_path_rate=0., + out_indices=(6, ), + frozen_stages=0, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='Swish'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + layer=['_BatchNorm', 'GroupNorm'], + val=1) + ]): + super(EfficientNet, self).__init__(init_cfg) + assert arch in self.arch_settings, \ + f'"{arch}" is not one of the arch_settings ' \ + f'({", ".join(self.arch_settings.keys())})' + self.arch_setting = self.arch_settings[arch] + # layer_settings of arch='l2' is 'b' + self.layer_setting = self.layer_settings['b' if arch == + 'l2' else arch[:1]] + for index in out_indices: + if index not in range(0, len(self.layer_setting)): + raise ValueError('the item in out_indices must in ' + f'range(0, {len(self.layer_setting)}). ' + f'But received {index}') + + if frozen_stages not in range(len(self.layer_setting) + 1): + raise ValueError('frozen_stages must be in range(0, ' + f'{len(self.layer_setting) + 1}). ' + f'But received {frozen_stages}') + self.drop_path_rate = drop_path_rate + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layer_setting = model_scaling(self.layer_setting, + self.arch_setting) + block_cfg_0 = self.layer_setting[0][0] + block_cfg_last = self.layer_setting[-1][0] + self.in_channels = make_divisible(block_cfg_0[1], 8) + self.out_channels = block_cfg_last[1] + self.layers = nn.ModuleList() + self.layers.append( + ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=block_cfg_0[0], + stride=block_cfg_0[3], + padding=block_cfg_0[0] // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.make_layer() + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=block_cfg_last[0], + stride=block_cfg_last[3], + padding=block_cfg_last[0] // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def make_layer(self): + # Without the first and the final conv block. + layer_setting = self.layer_setting[1:-1] + + total_num_blocks = sum([len(x) for x in layer_setting]) + block_idx = 0 + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, total_num_blocks) + ] # stochastic depth decay rule + + for layer_cfg in layer_setting: + layer = [] + for i, block_cfg in enumerate(layer_cfg): + (kernel_size, out_channels, se_ratio, stride, expand_ratio, + block_type) = block_cfg + + mid_channels = int(self.in_channels * expand_ratio) + out_channels = make_divisible(out_channels, 8) + if se_ratio <= 0: + se_cfg = None + else: + se_cfg = dict( + channels=mid_channels, + ratio=expand_ratio * se_ratio, + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + if block_type == 1: # edge tpu + if i > 0 and expand_ratio == 3: + with_residual = False + expand_ratio = 4 + else: + with_residual = True + mid_channels = int(self.in_channels * expand_ratio) + if se_cfg is not None: + se_cfg = dict( + channels=mid_channels, + ratio=se_ratio * expand_ratio, + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + block = partial(EdgeResidual, with_residual=with_residual) + else: + block = InvertedResidual + layer.append( + block( + in_channels=self.in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + drop_path_rate=dpr[block_idx], + with_cp=self.with_cp)) + self.in_channels = out_channels + block_idx += 1 + self.layers.append(Sequential(*layer)) + + def forward(self, x): + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpretrain/models/backbones/efficientnet_v2.py b/mmpretrain/models/backbones/efficientnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..fec002a4dac46f756f00ed8f596b37028ba18c37 --- /dev/null +++ b/mmpretrain/models/backbones/efficientnet_v2.py @@ -0,0 +1,343 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import ConvModule, DropPath +from mmengine.model import Sequential +from torch import Tensor + +from mmpretrain.registry import MODELS +from ..utils import InvertedResidual as MBConv +from .base_backbone import BaseBackbone +from .efficientnet import EdgeResidual as FusedMBConv + + +class EnhancedConvModule(ConvModule): + """ConvModule with short-cut and droppath. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + has_skip (bool): Whether there is short-cut. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + def __init__(self, *args, has_skip=False, drop_path_rate=0, **kwargs): + super().__init__(*args, **kwargs) + self.has_skip = has_skip + if self.has_skip and (self.in_channels != self.out_channels + or self.stride != (1, 1)): + raise ValueError('the stride must be 1 and the `in_channels` and' + ' `out_channels` must be the same , when ' + '`has_skip` is True in `EnhancedConvModule` .') + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate else nn.Identity() + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + short_cut = x + x = super().forward(x, **kwargs) + if self.has_skip: + x = self.drop_path(x) + short_cut + return x + + +@MODELS.register_module() +class EfficientNetV2(BaseBackbone): + """EfficientNetV2 backbone. + + A PyTorch implementation of EfficientNetV2 introduced by: + `EfficientNetV2: Smaller Models and Faster Training + `_ + + Args: + arch (str): Architecture of efficientnetv2. Defaults to s. + in_channels (int): Number of input image channels. Defaults to 3. + drop_path_rate (float): The ratio of the stochastic depth. + Defaults to 0.0. + out_indices (Sequence[int]): Output from which stages. + Defaults to (-1, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Defaults to dict(type='Swish'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + """ + + # Parameters to build layers. From left to right: + # - repeat (int): The repeat number of the block in the layer + # - kernel_size (int): The kernel size of the layer + # - stride (int): The stride of the first block of the layer + # - expand_ratio (int, float): The expand_ratio of the mid_channels + # - in_channel (int): The number of in_channels of the layer + # - out_channel (int): The number of out_channels of the layer + # - se_ratio (float): The sequeeze ratio of SELayer. + # - block_type (int): -2: ConvModule, -1: EnhancedConvModule, + # 0: FusedMBConv, 1: MBConv + arch_settings = { + **dict.fromkeys(['small', 's'], [[2, 3, 1, 1, 24, 24, 0.0, -1], + [4, 3, 2, 4, 24, 48, 0.0, 0], + [4, 3, 2, 4, 48, 64, 0.0, 0], + [6, 3, 2, 4, 64, 128, 0.25, 1], + [9, 3, 1, 6, 128, 160, 0.25, 1], + [15, 3, 2, 6, 160, 256, 0.25, 1], + [1, 1, 1, 1, 256, 1280, 0.0, -2]]), + **dict.fromkeys(['m', 'medium'], [[3, 3, 1, 1, 24, 24, 0.0, -1], + [5, 3, 2, 4, 24, 48, 0.0, 0], + [5, 3, 2, 4, 48, 80, 0.0, 0], + [7, 3, 2, 4, 80, 160, 0.25, 1], + [14, 3, 1, 6, 160, 176, 0.25, 1], + [18, 3, 2, 6, 176, 304, 0.25, 1], + [5, 3, 1, 6, 304, 512, 0.25, 1], + [1, 1, 1, 1, 512, 1280, 0.0, -2]]), + **dict.fromkeys(['l', 'large'], [[4, 3, 1, 1, 32, 32, 0.0, -1], + [7, 3, 2, 4, 32, 64, 0.0, 0], + [7, 3, 2, 4, 64, 96, 0.0, 0], + [10, 3, 2, 4, 96, 192, 0.25, 1], + [19, 3, 1, 6, 192, 224, 0.25, 1], + [25, 3, 2, 6, 224, 384, 0.25, 1], + [7, 3, 1, 6, 384, 640, 0.25, 1], + [1, 1, 1, 1, 640, 1280, 0.0, -2]]), + **dict.fromkeys(['xl'], [[4, 3, 1, 1, 32, 32, 0.0, -1], + [8, 3, 2, 4, 32, 64, 0.0, 0], + [8, 3, 2, 4, 64, 96, 0.0, 0], + [16, 3, 2, 4, 96, 192, 0.25, 1], + [24, 3, 1, 6, 192, 256, 0.25, 1], + [32, 3, 2, 6, 256, 512, 0.25, 1], + [8, 3, 1, 6, 512, 640, 0.25, 1], + [1, 1, 1, 1, 640, 1280, 0.0, -2]]), + **dict.fromkeys(['b0'], [[1, 3, 1, 1, 32, 16, 0.0, -1], + [2, 3, 2, 4, 16, 32, 0.0, 0], + [2, 3, 2, 4, 32, 48, 0.0, 0], + [3, 3, 2, 4, 48, 96, 0.25, 1], + [5, 3, 1, 6, 96, 112, 0.25, 1], + [8, 3, 2, 6, 112, 192, 0.25, 1], + [1, 1, 1, 1, 192, 1280, 0.0, -2]]), + **dict.fromkeys(['b1'], [[2, 3, 1, 1, 32, 16, 0.0, -1], + [3, 3, 2, 4, 16, 32, 0.0, 0], + [3, 3, 2, 4, 32, 48, 0.0, 0], + [4, 3, 2, 4, 48, 96, 0.25, 1], + [6, 3, 1, 6, 96, 112, 0.25, 1], + [9, 3, 2, 6, 112, 192, 0.25, 1], + [1, 1, 1, 1, 192, 1280, 0.0, -2]]), + **dict.fromkeys(['b2'], [[2, 3, 1, 1, 32, 16, 0.0, -1], + [3, 3, 2, 4, 16, 32, 0.0, 0], + [3, 3, 2, 4, 32, 56, 0.0, 0], + [4, 3, 2, 4, 56, 104, 0.25, 1], + [6, 3, 1, 6, 104, 120, 0.25, 1], + [10, 3, 2, 6, 120, 208, 0.25, 1], + [1, 1, 1, 1, 208, 1408, 0.0, -2]]), + **dict.fromkeys(['b3'], [[2, 3, 1, 1, 40, 16, 0.0, -1], + [3, 3, 2, 4, 16, 40, 0.0, 0], + [3, 3, 2, 4, 40, 56, 0.0, 0], + [5, 3, 2, 4, 56, 112, 0.25, 1], + [7, 3, 1, 6, 112, 136, 0.25, 1], + [12, 3, 2, 6, 136, 232, 0.25, 1], + [1, 1, 1, 1, 232, 1536, 0.0, -2]]) + } + + def __init__(self, + arch: str = 's', + in_channels: int = 3, + drop_path_rate: float = 0., + out_indices: Sequence[int] = (-1, ), + frozen_stages: int = 0, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.1), + act_cfg=dict(type='Swish'), + norm_eval: bool = False, + with_cp: bool = False, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + layer=['_BatchNorm', 'GroupNorm'], + val=1) + ]): + super(EfficientNetV2, self).__init__(init_cfg) + assert arch in self.arch_settings, \ + f'"{arch}" is not one of the arch_settings ' \ + f'({", ".join(self.arch_settings.keys())})' + self.arch = self.arch_settings[arch] + if frozen_stages not in range(len(self.arch) + 1): + raise ValueError('frozen_stages must be in range(0, ' + f'{len(self.arch)}), but get {frozen_stages}') + self.drop_path_rate = drop_path_rate + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layers = nn.ModuleList() + assert self.arch[-1][-1] == -2, \ + f'the last block_type of `arch_setting` must be -2 ,' \ + f'but get `{self.arch[-1][-1]}`' + self.in_channels = in_channels + self.out_channels = self.arch[-1][5] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.make_layers() + + # there len(slef.arch) + 2 layers in the backbone + # including: the first + len(self.arch) layers + the last + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = len(self.layers) + index + assert 0 <= out_indices[i] <= len(self.layers), \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + def make_layers(self, ): + # make the first layer + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=self.arch[0][4], + kernel_size=3, + stride=2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + in_channels = self.arch[0][4] + layer_setting = self.arch[:-1] + + total_num_blocks = sum([x[0] for x in layer_setting]) + block_idx = 0 + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, total_num_blocks) + ] # stochastic depth decay rule + + for layer_cfg in layer_setting: + layer = [] + (repeat, kernel_size, stride, expand_ratio, _, out_channels, + se_ratio, block_type) = layer_cfg + for i in range(repeat): + stride = stride if i == 0 else 1 + if block_type == -1: + has_skip = stride == 1 and in_channels == out_channels + droppath_rate = dpr[block_idx] if has_skip else 0.0 + layer.append( + EnhancedConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + has_skip=has_skip, + drop_path_rate=droppath_rate, + stride=stride, + padding=1, + conv_cfg=None, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + in_channels = out_channels + else: + mid_channels = int(in_channels * expand_ratio) + se_cfg = None + if block_type != 0 and se_ratio > 0: + se_cfg = dict( + channels=mid_channels, + ratio=expand_ratio * (1.0 / se_ratio), + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + block = FusedMBConv if block_type == 0 else MBConv + conv_cfg = self.conv_cfg if stride == 2 else None + layer.append( + block( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + drop_path_rate=dpr[block_idx], + with_cp=self.with_cp)) + in_channels = out_channels + block_idx += 1 + self.layers.append(Sequential(*layer)) + + # make the last layer + self.layers.append( + ConvModule( + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=self.arch[-1][1], + stride=self.arch[-1][2], + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, x: Tensor) -> Tuple[Tensor]: + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpretrain/models/backbones/hivit.py b/mmpretrain/models/backbones/hivit.py new file mode 100644 index 0000000000000000000000000000000000000000..981cbf819138ace2c2e8441e7e65f927883c55fd --- /dev/null +++ b/mmpretrain/models/backbones/hivit.py @@ -0,0 +1,656 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer, to_2tuple +from .base_backbone import BaseBackbone + + +class Mlp(nn.Module): + """MLP block. + + Args: + in_features (int): Number of input dims. + hidden_features (int): Number of hidden dims. + out_feature (int): Number of out dims. + act_layer: MLP activation layer. + drop (float): MLP dropout rate. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + """Attention. + + Args: + input size (int): Input size. + dim (int): Number of input dims. + num_heads (int): Number of attention heads. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + proj_drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + rpe (bool): If True, add relative position embedding to + the patch embedding. + """ + + def __init__(self, + input_size, + dim, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + rpe=True): + super().__init__() + self.input_size = input_size + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * input_size - 1) * + (2 * input_size - 1), num_heads)) if rpe else None + if rpe: + coords_h = torch.arange(input_size) + coords_w = torch.arange(input_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += input_size - 1 + relative_coords[:, :, 1] += input_size - 1 + relative_coords[:, :, 0] *= 2 * input_size - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer('relative_position_index', + relative_position_index) + + trunc_normal_(self.relative_position_bias_table, std=.02) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, rpe_index=None, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if rpe_index is not None: + rpe_index = self.relative_position_index.view(-1) + S = int(math.sqrt(rpe_index.size(-1))) + relative_position_bias = self.relative_position_bias_table[ + rpe_index].view(-1, S, S, self.num_heads) + relative_position_bias = relative_position_bias.permute( + 0, 3, 1, 2).contiguous() + attn = attn + relative_position_bias + if mask is not None: + mask = mask.bool() + attn = attn.masked_fill(~mask[:, None, None, :], float('-inf')) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class BlockWithRPE(nn.Module): + """HiViT block. + + Args: + input_size (int): Input size. + dim (int): Number of input dims. + num_heads (int): Number of attention heads. + mlp_ratio (int): Ratio of MLP hidden dim to embedding dim. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. + drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + rpe (bool): If True, add relative position embedding to + the patch embedding. + layer_scale_init_value (float): Layer-scale init values. Defaults to 0. + act_layer: MLP activation layer. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + """ + + def __init__(self, + input_size, + dim, + num_heads=0., + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + rpe=True, + layer_scale_init_value=0.0, + act_layer=nn.GELU, + norm_cfg=dict(type='LN')): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + + with_attn = num_heads > 0. + + self.norm1 = build_norm_layer(norm_cfg, dim) if with_attn else None + self.attn = Attention( + input_size, + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + rpe=rpe, + ) if with_attn else None + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = build_norm_layer(norm_cfg, dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if layer_scale_init_value > 0: + self.gamma_1 = nn.Parameter( + layer_scale_init_value * torch.ones( + (dim)), requires_grad=True) if with_attn else None + self.gamma_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rpe_index=None, mask=None): + if self.attn is not None: + if self.gamma_1 is not None: + x = x + self.drop_path( + self.gamma_1 * self.attn(self.norm1(x), rpe_index, mask)) + else: + x = x + self.drop_path( + self.attn(self.norm1(x), rpe_index, mask)) + if self.gamma_2 is not None: + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """PatchEmbed for HiViT. + + Args: + img_size (int): Input image size. + patch_size (int): Patch size. Defaults to 16. + inner_patches (int): Inner patch. Defaults to 4. + in_chans (int): Number of image input channels. + embed_dim (int): Transformer embedding dimension. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + kernel_size (int): Kernel size. + pad_size (int): Pad size. + """ + + def __init__(self, + img_size=224, + patch_size=16, + inner_patches=4, + in_chans=3, + embed_dim=128, + norm_cfg=None, + kernel_size=None, + pad_size=None): + super().__init__() + img_size = to_2tuple(img_size) if not isinstance(img_size, + tuple) else img_size + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.inner_patches = inner_patches + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + conv_size = [size // inner_patches for size in patch_size] + kernel_size = kernel_size or conv_size + pad_size = pad_size or 0 + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=conv_size, + padding=pad_size) + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + patches_resolution = (H // self.patch_size[0], W // self.patch_size[1]) + num_patches = patches_resolution[0] * patches_resolution[1] + x = self.proj(x).view( + B, + -1, + patches_resolution[0], + self.inner_patches, + patches_resolution[1], + self.inner_patches, + ).permute(0, 2, 4, 3, 5, 1).reshape(B, num_patches, self.inner_patches, + self.inner_patches, -1) + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchMerge(nn.Module): + """PatchMerge for HiViT. + + Args: + dim (int): Number of input channels. + norm_cfg (dict): Config dict for normalization layer. + """ + + def __init__(self, dim, norm_cfg): + super().__init__() + self.norm = build_norm_layer(norm_cfg, dim * 4) + self.reduction = nn.Linear(dim * 4, dim * 2, bias=False) + + def forward(self, x, *args, **kwargs): + is_main_stage = len(x.shape) == 3 + if is_main_stage: + B, N, C = x.shape + S = int(math.sqrt(N)) + x = x.reshape(B, S // 2, 2, S // 2, 2, C) \ + .permute(0, 1, 3, 2, 4, 5) \ + .reshape(B, -1, 2, 2, C) + x0 = x[..., 0::2, 0::2, :] + x1 = x[..., 1::2, 0::2, :] + x2 = x[..., 0::2, 1::2, :] + x3 = x[..., 1::2, 1::2, :] + + x = torch.cat([x0, x1, x2, x3], dim=-1) + x = self.norm(x) + x = self.reduction(x) + + if is_main_stage: + x = x[:, :, 0, 0, :] + return x + + +@MODELS.register_module() +class HiViT(BaseBackbone): + """HiViT. + + A PyTorch implement of: `HiViT: A Simple and More Efficient Design + of Hierarchical Vision Transformer `_. + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', and'base'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (int): The number of heads in attention + modules of each stage. + + Defaults to 'tiny'. + img_size (int): Input image size. + patch_size (int): Patch size. Defaults to 16. + inner_patches (int): Inner patch. Defaults to 4. + in_chans (int): Number of image input channels. + embed_dim (int): Transformer embedding dimension. + depths (list[int]): Number of successive HiViT blocks. + num_heads (int): Number of attention heads. + stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim + in the first two stages. + mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in + the last stage. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): If True, add absolute position embedding to + the patch embedding. + rpe (bool): If True, add relative position embedding to + the patch embedding. + patch_norm (bool): If True, use norm_cfg for normalization layer. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + kernel_size (int): Kernel size. + pad_size (int): Pad size. + layer_scale_init_value (float): Layer-scale init values. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 384, + 'depths': [1, 1, 10], + 'num_heads': 6}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 384, + 'depths': [2, 2, 20], + 'num_heads': 6}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 512, + 'depths': [2, 2, 24], + 'num_heads': 8}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 768, + 'depths': [2, 2, 40], + 'num_heads': 12}), + } # yapf: disable + + num_extra_tokens = 0 + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + inner_patches=4, + in_chans=3, + stem_mlp_ratio=3., + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.0, + norm_cfg=dict(type='LN'), + out_indices=[23], + ape=True, + rpe=False, + patch_norm=True, + frozen_stages=-1, + kernel_size=None, + pad_size=None, + layer_scale_init_value=0.0, + init_cfg=None): + super(HiViT, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + + self.num_stages = len(self.depths) + self.ape = ape + self.rpe = rpe + self.patch_size = patch_size + self.num_features = self.embed_dims + self.mlp_ratio = mlp_ratio + self.num_main_blocks = self.depths[-1] + self.out_indices = out_indices + self.out_indices[-1] = self.depths[-1] - 1 + + img_size = to_2tuple(img_size) if not isinstance(img_size, + tuple) else img_size + + embed_dim = self.embed_dims // 2**(self.num_stages - 1) + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + in_chans=in_chans, + embed_dim=embed_dim, + norm_cfg=norm_cfg if patch_norm else None, + kernel_size=kernel_size, + pad_size=pad_size) + num_patches = self.patch_embed.num_patches + Hp, Wp = self.patch_embed.patches_resolution + + if rpe: + assert Hp == Wp, 'If you use relative position, make sure H == W ' + 'of input size' + + # absolute position embedding + if ape: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.num_features)) + trunc_normal_(self.pos_embed, std=.02) + if rpe: + # get pair-wise relative position index for each token inside the + # window + coords_h = torch.arange(Hp) + coords_w = torch.arange(Wp) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += Hp - 1 + relative_coords[:, :, 1] += Wp - 1 + relative_coords[:, :, 0] *= 2 * Wp - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer('relative_position_index', + relative_position_index) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = iter( + x.item() + for x in torch.linspace(0, drop_path_rate, + sum(self.depths) + sum(self.depths[:-1]))) + + # build blocks + self.blocks = nn.ModuleList() + for stage_i, stage_depth in enumerate(self.depths): + is_main_stage = embed_dim == self.num_features + nhead = self.num_heads if is_main_stage else 0 + ratio = mlp_ratio if is_main_stage else stem_mlp_ratio + # every block not in main stage includes two mlp blocks + stage_depth = stage_depth if is_main_stage else stage_depth * 2 + for _ in range(stage_depth): + self.blocks.append( + BlockWithRPE( + Hp, + embed_dim, + nhead, + ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=next(dpr), + rpe=rpe, + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + )) + if stage_i + 1 < self.num_stages: + self.blocks.append(PatchMerge(embed_dim, norm_cfg)) + embed_dim *= 2 + + self.frozen_stages = frozen_stages + if self.frozen_stages > 0: + self._freeze_stages() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, h, w): + npatch = x.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + patch_pos_embed = self.pos_embed + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), + dim).permute(0, 3, 1, 2), + scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(h0) == patch_pos_embed.shape[-2] and int( + w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, x): + B, C, H, W = x.shape + Hp, Wp = H // self.patch_size, W // self.patch_size + + x = self.patch_embed(x) + + outs = [] + for i, blk in enumerate(self.blocks[:-self.num_main_blocks]): + x = blk(x) + if i in self.out_indices: + x = x.reshape(B, Hp, Wp, *x.shape[-3:]).permute( + 0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * x.shape[-3], + Wp * x.shape[-2]).contiguous() + outs.append(x) + + x = x[..., 0, 0, :] + if self.ape: + x = x + self.interpolate_pos_encoding(x, H, W) + x = self.pos_drop(x) + + rpe_index = True if self.rpe else None + + for i, blk in enumerate(self.blocks[-self.num_main_blocks:]): + x = blk(x, rpe_index) + if i in self.out_indices: + x = x.transpose(1, 2).view(B, -1, Hp, Wp).contiguous() + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.pos_drop.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + for param in self.fc_norm.parameters(): + param.requires_grad = False + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + self.num_layers = len(self.blocks) + num_layers = self.num_layers + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name in 'pos_embed': + layer_depth = 0 + elif param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + layer_depth = layer_id + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/hornet.py b/mmpretrain/models/backbones/hornet.py new file mode 100644 index 0000000000000000000000000000000000000000..460f2dc57975712b5eae8308e2fca9c38b89a3e2 --- /dev/null +++ b/mmpretrain/models/backbones/hornet.py @@ -0,0 +1,500 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from official impl at https://github.com/raoyongming/HorNet. +try: + import torch.fft + fft = True +except ImportError: + fft = None + +import copy +from functools import partial +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from mmcv.cnn.bricks import DropPath + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import LayerScale + + +def get_dwconv(dim, kernel_size, bias=True): + """build a pepth-wise convolution.""" + return nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + bias=bias, + groups=dim) + + +class HorNetLayerNorm(nn.Module): + """An implementation of LayerNorm of HorNet. + + The differences between HorNetLayerNorm & torch LayerNorm: + 1. Supports two data formats channels_last or channels_first. + Args: + normalized_shape (int or list or torch.Size): input shape from an + expected input of size. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-5. + data_format (str): The ordering of the dimensions in the inputs. + channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with + shape (batch_size, channels, height, width). + Defaults to 'channels_last'. + """ + + def __init__(self, + normalized_shape, + eps=1e-6, + data_format='channels_last'): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ['channels_last', 'channels_first']: + raise ValueError( + 'data_format must be channels_last or channels_first') + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == 'channels_last': + return F.layer_norm(x, self.normalized_shape, self.weight, + self.bias, self.eps) + elif self.data_format == 'channels_first': + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GlobalLocalFilter(nn.Module): + """A GlobalLocalFilter of HorNet. + + Args: + dim (int): Number of input channels. + h (int): Height of complex_weight. + Defaults to 14. + w (int): Width of complex_weight. + Defaults to 8. + """ + + def __init__(self, dim, h=14, w=8): + super().__init__() + self.dw = nn.Conv2d( + dim // 2, + dim // 2, + kernel_size=3, + padding=1, + bias=False, + groups=dim // 2) + self.complex_weight = nn.Parameter( + torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02) + self.pre_norm = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + self.post_norm = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + + def forward(self, x): + x = self.pre_norm(x) + x1, x2 = torch.chunk(x, 2, dim=1) + x1 = self.dw(x1) + + x2 = x2.to(torch.float32) + B, C, a, b = x2.shape + x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho') + + weight = self.complex_weight + if not weight.shape[1:3] == x2.shape[2:4]: + weight = F.interpolate( + weight.permute(3, 0, 1, 2), + size=x2.shape[2:4], + mode='bilinear', + align_corners=True).permute(1, 2, 3, 0) + + weight = torch.view_as_complex(weight.contiguous()) + + x2 = x2 * weight + x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho') + + x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], + dim=2).reshape(B, 2 * C, a, b) + x = self.post_norm(x) + return x + + +class gnConv(nn.Module): + """A gnConv of HorNet. + + Args: + dim (int): Number of input channels. + order (int): Order of gnConv. + Defaults to 5. + dw_cfg (dict): The Config for dw conv. + Defaults to ``dict(type='DW', kernel_size=7)``. + scale (float): Scaling parameter of gflayer outputs. + Defaults to 1.0. + """ + + def __init__(self, + dim, + order=5, + dw_cfg=dict(type='DW', kernel_size=7), + scale=1.0): + super().__init__() + self.order = order + self.dims = [dim // 2**i for i in range(order)] + self.dims.reverse() + self.proj_in = nn.Conv2d(dim, 2 * dim, 1) + + cfg = copy.deepcopy(dw_cfg) + dw_type = cfg.pop('type') + assert dw_type in ['DW', 'GF'],\ + 'dw_type should be `DW` or `GF`' + if dw_type == 'DW': + self.dwconv = get_dwconv(sum(self.dims), **cfg) + elif dw_type == 'GF': + self.dwconv = GlobalLocalFilter(sum(self.dims), **cfg) + + self.proj_out = nn.Conv2d(dim, dim, 1) + + self.projs = nn.ModuleList([ + nn.Conv2d(self.dims[i], self.dims[i + 1], 1) + for i in range(order - 1) + ]) + + self.scale = scale + + def forward(self, x): + x = self.proj_in(x) + y, x = torch.split(x, (self.dims[0], sum(self.dims)), dim=1) + + x = self.dwconv(x) * self.scale + + dw_list = torch.split(x, self.dims, dim=1) + x = y * dw_list[0] + + for i in range(self.order - 1): + x = self.projs[i](x) * dw_list[i + 1] + + x = self.proj_out(x) + + return x + + +class HorNetBlock(nn.Module): + """A block of HorNet. + + Args: + dim (int): Number of input channels. + order (int): Order of gnConv. + Defaults to 5. + dw_cfg (dict): The Config for dw conv. + Defaults to ``dict(type='DW', kernel_size=7)``. + scale (float): Scaling parameter of gflayer outputs. + Defaults to 1.0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + use_layer_scale (bool): Whether to use use_layer_scale in HorNet + block. Defaults to True. + """ + + def __init__(self, + dim, + order=5, + dw_cfg=dict(type='DW', kernel_size=7), + scale=1.0, + drop_path_rate=0., + use_layer_scale=True): + super().__init__() + self.out_channels = dim + + self.norm1 = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + self.gnconv = gnConv(dim, order, dw_cfg, scale) + self.norm2 = HorNetLayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + + if use_layer_scale: + self.gamma1 = LayerScale(dim, data_format='channels_first') + self.gamma2 = LayerScale(dim) + else: + self.gamma1, self.gamma2 = nn.Identity(), nn.Identity() + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.gamma1(self.gnconv(self.norm1(x)))) + + input = x + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm2(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + x = self.gamma2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +@MODELS.register_module() +class HorNet(BaseBackbone): + """HorNet backbone. + + A PyTorch implementation of paper `HorNet: Efficient High-Order Spatial + Interactions with Recursive Gated Convolutions + `_ . + Inspiration from https://github.com/raoyongming/HorNet + + Args: + arch (str | dict): HorNet architecture. + + If use string, choose from 'tiny', 'small', 'base' and 'large'. + If use dict, it should have below keys: + + - **base_dim** (int): The base dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **orders** (List[int]): The number of order of gnConv in each + stage. + - **dw_cfg** (List[dict]): The Config for dw conv. + + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + scale (float): Scaling parameter of gflayer outputs. Defaults to 1/3. + use_layer_scale (bool): Whether to use use_layer_scale in HorNet + block. Defaults to True. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'base_dim': 64, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['t-gf', 'tiny-gf'], + {'base_dim': 64, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['s', 'small'], + {'base_dim': 96, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['s-gf', 'small-gf'], + {'base_dim': 96, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['b', 'base'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['b-gf', 'base-gf'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['b-gf384', 'base-gf384'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=24, w=12), + dict(type='GF', h=13, w=7)]}), + **dict.fromkeys(['l', 'large'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['l-gf', 'large-gf'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['l-gf384', 'large-gf384'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=24, w=12), + dict(type='GF', h=13, w=7)]}), + } # yapf: disable + + def __init__(self, + arch='tiny', + in_channels=3, + drop_path_rate=0., + scale=1 / 3, + use_layer_scale=True, + out_indices=(3, ), + frozen_stages=-1, + with_cp=False, + gap_before_final_norm=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if fft is None: + raise RuntimeError( + 'Failed to import torch.fft. Please install "torch>=1.7".') + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'base_dim', 'depths', 'orders', 'dw_cfg'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.scale = scale + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.with_cp = with_cp + self.gap_before_final_norm = gap_before_final_norm + + base_dim = self.arch_settings['base_dim'] + dims = list(map(lambda x: 2**x * base_dim, range(4))) + + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4), + HorNetLayerNorm(dims[0], eps=1e-6, data_format='channels_first')) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + HorNetLayerNorm( + dims[i], eps=1e-6, data_format='channels_first'), + nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + total_depth = sum(self.arch_settings['depths']) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + cur_block_idx = 0 + self.stages = nn.ModuleList() + for i in range(4): + stage = nn.Sequential(*[ + HorNetBlock( + dim=dims[i], + order=self.arch_settings['orders'][i], + dw_cfg=self.arch_settings['dw_cfg'][i], + scale=self.scale, + drop_path_rate=dpr[cur_block_idx + j], + use_layer_scale=use_layer_scale) + for j in range(self.arch_settings['depths'][i]) + ]) + self.stages.append(stage) + cur_block_idx += self.arch_settings['depths'][i] + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = len(self.stages) + index + assert 0 <= out_indices[i] <= len(self.stages), \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + norm_layer = partial( + HorNetLayerNorm, eps=1e-6, data_format='channels_first') + for i_layer in out_indices: + layer = norm_layer(dims[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + def train(self, mode=True): + super(HorNet, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + # freeze patch embed + m = self.downsample_layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze blocks + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if i in self.out_indices: + # freeze norm + m = getattr(self, f'norm{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + for i in range(4): + x = self.downsample_layers[i](x) + if self.with_cp: + x = checkpoint.checkpoint_sequential(self.stages[i], + len(self.stages[i]), x) + else: + x = self.stages[i](x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap).flatten(1)) + else: + # The output of LayerNorm2d may be discontiguous, which + # may cause some problem in the downstream tasks + outs.append(norm_layer(x).contiguous()) + return tuple(outs) diff --git a/mmpretrain/models/backbones/hrnet.py b/mmpretrain/models/backbones/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..99afa908531326f05ff1c977f0146a528683af43 --- /dev/null +++ b/mmpretrain/models/backbones/hrnet.py @@ -0,0 +1,563 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. + + Args: + num_branches (int): The number of branches. + block (``BaseModule``): Convolution block module. + num_blocks (tuple): The number of blocks in each branch. + The length must be equal to ``num_branches``. + num_channels (tuple): The number of base channels in each branch. + The length must be equal to ``num_branches``. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + conv_cfg (dict, optional): Dictionary to construct and config conv + layer. Defaults to None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN')``. + block_init_cfg (dict, optional): The initialization configs of every + blocks. Defaults to None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + num_branches, + block, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + block_init_cfg=None, + init_cfg=None): + super(HRModule, self).__init__(init_cfg) + self.block_init_cfg = block_init_cfg + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, block, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_BLOCKS({len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_CHANNELS({len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_INCHANNELS({len(in_channels)})' + raise ValueError(error_msg) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + out_channels = num_channels[i] * get_expansion(block) + branches.append( + ResLayer( + block=block, + num_blocks=num_blocks[i], + in_channels=self.in_channels[i], + out_channels=out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + with_cp=self.with_cp, + init_cfg=self.block_init_cfg, + )) + + return ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + # Upsample the feature maps of smaller scales. + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + # Keep the feature map with the same scale. + fuse_layer.append(None) + else: + # Downsample the feature maps of larger scales. + conv_downsamples = [] + for k in range(i - j): + # Use stacked convolution layers to downsample. + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@MODELS.register_module() +class HRNet(BaseModule): + """HRNet backbone. + + `High-Resolution Representations for Labeling Pixels and Regions + `_. + + Args: + arch (str): The preset HRNet architecture, includes 'w18', 'w30', + 'w32', 'w40', 'w44', 'w48', 'w64'. It will only be used if + extra is ``None``. Defaults to 'w32'. + extra (dict, optional): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of convolution block. Please choose between + 'BOTTLENECK' and 'BASIC'. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of base channels in each branch. + The length must be equal to num_branches. + + Defaults to None. + in_channels (int): Number of input image channels. Defaults to 3. + conv_cfg (dict, optional): Dictionary to construct and config conv + layer. Defaults to None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN')``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to False. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + + Example: + >>> import torch + >>> from mmpretrain.models import HRNet + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + arch_zoo = { + # num_modules, num_branches, block, num_blocks, num_channels + 'w18': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (18, 36)], + [4, 3, 'BASIC', (4, 4, 4), (18, 36, 72)], + [3, 4, 'BASIC', (4, 4, 4, 4), (18, 36, 72, 144)]], + 'w30': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (30, 60)], + [4, 3, 'BASIC', (4, 4, 4), (30, 60, 120)], + [3, 4, 'BASIC', (4, 4, 4, 4), (30, 60, 120, 240)]], + 'w32': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (32, 64)], + [4, 3, 'BASIC', (4, 4, 4), (32, 64, 128)], + [3, 4, 'BASIC', (4, 4, 4, 4), (32, 64, 128, 256)]], + 'w40': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (40, 80)], + [4, 3, 'BASIC', (4, 4, 4), (40, 80, 160)], + [3, 4, 'BASIC', (4, 4, 4, 4), (40, 80, 160, 320)]], + 'w44': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (44, 88)], + [4, 3, 'BASIC', (4, 4, 4), (44, 88, 176)], + [3, 4, 'BASIC', (4, 4, 4, 4), (44, 88, 176, 352)]], + 'w48': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (48, 96)], + [4, 3, 'BASIC', (4, 4, 4), (48, 96, 192)], + [3, 4, 'BASIC', (4, 4, 4, 4), (48, 96, 192, 384)]], + 'w64': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (64, 128)], + [4, 3, 'BASIC', (4, 4, 4), (64, 128, 256)], + [3, 4, 'BASIC', (4, 4, 4, 4), (64, 128, 256, 512)]], + } # yapf:disable + + def __init__(self, + arch='w32', + extra=None, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + norm_eval=False, + with_cp=False, + zero_init_residual=False, + multiscale_output=True, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(HRNet, self).__init__(init_cfg) + + extra = self.parse_arch(arch, extra) + + # Assert configurations of 4 stages are in extra + for i in range(1, 5): + assert f'stage{i}' in extra, f'Missing stage{i} config in "extra".' + # Assert whether the length of `num_blocks` and `num_channels` are + # equal to `num_branches` + cfg = extra[f'stage{i}'] + assert len(cfg['num_blocks']) == cfg['num_branches'] and \ + len(cfg['num_channels']) == cfg['num_branches'] + + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.zero_init_residual = zero_init_residual + + # -------------------- stem net -------------------- + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.add_module(self.norm1_name, norm1) + + self.conv2 = build_conv_layer( + self.conv_cfg, + in_channels=64, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # -------------------- stage 1 -------------------- + self.stage1_cfg = self.extra['stage1'] + base_channels = self.stage1_cfg['num_channels'] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'] + + block = self.blocks_dict[block_type] + num_channels = [ + channel * get_expansion(block) for channel in base_channels + ] + # To align with the original code, use layer1 instead of stage1 here. + self.layer1 = ResLayer( + block, + in_channels=64, + out_channels=num_channels[0], + num_blocks=num_blocks[0]) + pre_num_channels = num_channels + + # -------------------- stage 2~4 -------------------- + for i in range(2, 5): + stage_cfg = self.extra[f'stage{i}'] + base_channels = stage_cfg['num_channels'] + block = self.blocks_dict[stage_cfg['block']] + multiscale_output_ = multiscale_output if i == 4 else True + + num_channels = [ + channel * get_expansion(block) for channel in base_channels + ] + # The transition layer from layer1 to stage2 + transition = self._make_transition_layer(pre_num_channels, + num_channels) + self.add_module(f'transition{i-1}', transition) + stage = self._make_stage( + stage_cfg, num_channels, multiscale_output=multiscale_output_) + self.add_module(f'stage{i}', stage) + + pre_num_channels = num_channels + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + # For existing scale branches, + # add conv block when the channels are not the same. + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(nn.Identity()) + else: + # For new scale branches, add stacked downsample conv blocks. + # For example, num_branches_pre = 2, for the 4th branch, add + # stacked two downsample conv blocks. + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + block_init_cfg = None + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + block_init_cfg=block_init_cfg)) + + return Sequential(*hr_modules) + + def forward(self, x): + """Forward function.""" + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [x] + + for i in range(2, 5): + # Apply transition + transition = getattr(self, f'transition{i-1}') + inputs = [] + for j, layer in enumerate(transition): + if j < len(x_list): + inputs.append(layer(x_list[j])) + else: + inputs.append(layer(x_list[-1])) + # Forward HRModule + stage = getattr(self, f'stage{i}') + x_list = stage(inputs) + + return tuple(x_list) + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super(HRNet, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def parse_arch(self, arch, extra=None): + if extra is not None: + return extra + + assert arch in self.arch_zoo, \ + ('Invalid arch, please choose arch from ' + f'{list(self.arch_zoo.keys())}, or specify `extra` ' + 'argument directly.') + + extra = dict() + for i, stage_setting in enumerate(self.arch_zoo[arch], start=1): + extra[f'stage{i}'] = dict( + num_modules=stage_setting[0], + num_branches=stage_setting[1], + block=stage_setting[2], + num_blocks=stage_setting[3], + num_channels=stage_setting[4], + ) + + return extra diff --git a/mmpretrain/models/backbones/inception_v3.py b/mmpretrain/models/backbones/inception_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..1d6c04b9fba4b50fce31539d14874dc7a47a539a --- /dev/null +++ b/mmpretrain/models/backbones/inception_v3.py @@ -0,0 +1,501 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_conv_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class BasicConv2d(BaseModule): + """A basic convolution block including convolution, batch norm and ReLU. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + conv_cfg (dict, optional): The config of convolution layer. + Defaults to None, which means to use ``nn.Conv2d``. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + **kwargs: Other keyword arguments of the convolution layer. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs) -> None: + super().__init__(init_cfg=init_cfg) + self.conv = build_conv_layer( + conv_cfg, in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class InceptionA(BaseModule): + """Type-A Inception block. + + Args: + in_channels (int): The number of input channels. + pool_features (int): The number of channels in pooling branch. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + pool_features: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + + self.branch5x5_1 = BasicConv2d( + in_channels, 48, kernel_size=1, conv_cfg=conv_cfg) + self.branch5x5_2 = BasicConv2d( + 48, 64, kernel_size=5, padding=2, conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3 = BasicConv2d( + 96, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, pool_features, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionB(BaseModule): + """Type-B Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch3x3 = BasicConv2d( + in_channels, 384, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3 = BasicConv2d( + 96, 96, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = self.branch_pool(x) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionC(BaseModule): + """Type-C Inception block. + + Args: + in_channels (int): The number of input channels. + channels_7x7 (int): The number of channels in 7x7 convolution branch. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + channels_7x7: int, + conv_cfg: Optional[dict] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + c7 = channels_7x7 + self.branch7x7_1 = BasicConv2d( + in_channels, c7, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7_2 = BasicConv2d( + c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7_3 = BasicConv2d( + c7, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + + self.branch7x7dbl_1 = BasicConv2d( + in_channels, c7, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7dbl_2 = BasicConv2d( + c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7dbl_3 = BasicConv2d( + c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7dbl_4 = BasicConv2d( + c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7dbl_5 = BasicConv2d( + c7, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionD(BaseModule): + """Type-D Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch3x3_1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3_2 = BasicConv2d( + 192, 320, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch7x7x3_1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7x3_2 = BasicConv2d( + 192, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7x3_3 = BasicConv2d( + 192, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7x3_4 = BasicConv2d( + 192, 192, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = self.branch_pool(x) + outputs = [branch3x3, branch7x7x3, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionE(BaseModule): + """Type-E Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 320, kernel_size=1, conv_cfg=conv_cfg) + + self.branch3x3_1 = BasicConv2d( + in_channels, 384, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3_2a = BasicConv2d( + 384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg) + self.branch3x3_2b = BasicConv2d( + 384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 448, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 448, 384, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3a = BasicConv2d( + 384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg) + self.branch3x3dbl_3b = BasicConv2d( + 384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionAux(BaseModule): + """The Inception block for the auxiliary classification branch. + + Args: + in_channels (int): The number of input channels. + num_classes (int): The number of categroies. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to use trunc normal with ``std=0.01`` for Conv2d layers + and use trunc normal with ``std=0.001`` for Linear layers.. + """ + + def __init__(self, + in_channels: int, + num_classes: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = [ + dict(type='TruncNormal', layer='Conv2d', std=0.01), + dict(type='TruncNormal', layer='Linear', std=0.001) + ]): + super().__init__(init_cfg=init_cfg) + self.downsample = nn.AvgPool2d(kernel_size=5, stride=3) + self.conv0 = BasicConv2d( + in_channels, 128, kernel_size=1, conv_cfg=conv_cfg) + self.conv1 = BasicConv2d(128, 768, kernel_size=5, conv_cfg=conv_cfg) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(768, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + # N x 768 x 17 x 17 + x = self.downsample(x) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # Adaptive average pooling + x = self.gap(x) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +@MODELS.register_module() +class InceptionV3(BaseBackbone): + """Inception V3 backbone. + + A PyTorch implementation of `Rethinking the Inception Architecture for + Computer Vision `_ + + This implementation is modified from + https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py. + Licensed under the BSD 3-Clause License. + + Args: + num_classes (int): The number of categroies. Defaults to 1000. + aux_logits (bool): Whether to enable the auxiliary branch. If False, + the auxiliary logits output will be None. Defaults to False. + dropout (float): Dropout rate. Defaults to 0.5. + init_cfg (dict, optional): The config of initialization. Defaults + to use trunc normal with ``std=0.1`` for all Conv2d and Linear + layers and constant with ``val=1`` for all BatchNorm2d layers. + + Example: + >>> import torch + >>> from mmpretrain.models import build_backbone + >>> + >>> inputs = torch.rand(2, 3, 299, 299) + >>> cfg = dict(type='InceptionV3', num_classes=100) + >>> backbone = build_backbone(cfg) + >>> aux_out, out = backbone(inputs) + >>> # The auxiliary branch is disabled by default. + >>> assert aux_out is None + >>> print(out.shape) + torch.Size([2, 100]) + >>> cfg = dict(type='InceptionV3', num_classes=100, aux_logits=True) + >>> backbone = build_backbone(cfg) + >>> aux_out, out = backbone(inputs) + >>> print(aux_out.shape, out.shape) + torch.Size([2, 100]) torch.Size([2, 100]) + """ + + def __init__( + self, + num_classes: int = 1000, + aux_logits: bool = False, + dropout: float = 0.5, + init_cfg: Optional[dict] = [ + dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=0.1), + dict(type='Constant', layer='BatchNorm2d', val=1) + ], + ) -> None: + super().__init__(init_cfg=init_cfg) + + self.aux_logits = aux_logits + self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + self.AuxLogits: Optional[nn.Module] = None + if aux_logits: + self.AuxLogits = InceptionAux(768, num_classes) + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE(1280) + self.Mixed_7c = InceptionE(2048) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(p=dropout) + self.fc = nn.Linear(2048, num_classes) + + def forward( + self, + x: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + """Forward function.""" + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.maxpool1(x) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.maxpool2(x) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + aux: Optional[torch.Tensor] = None + if self.aux_logits and self.training: + aux = self.AuxLogits(x) + # N x 768 x 17 x 17 + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + # Adaptive average pooling + x = self.avgpool(x) + # N x 2048 x 1 x 1 + x = self.dropout(x) + # N x 2048 x 1 x 1 + x = torch.flatten(x, 1) + # N x 2048 + x = self.fc(x) + # N x 1000 (num_classes) + return aux, x diff --git a/mmpretrain/models/backbones/lenet.py b/mmpretrain/models/backbones/lenet.py new file mode 100644 index 0000000000000000000000000000000000000000..8e423c0b15a60660714617e47fd68857b3a6d1e0 --- /dev/null +++ b/mmpretrain/models/backbones/lenet.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class LeNet5(BaseBackbone): + """`LeNet5 `_ backbone. + + The input for LeNet-5 is a 32×32 grayscale image. + + Args: + num_classes (int): number of classes for classification. + The default value is -1, which uses the backbone as + a feature extractor without the top classifier. + """ + + def __init__(self, num_classes=-1): + super(LeNet5, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(1, 6, kernel_size=5, stride=1), nn.Tanh(), + nn.AvgPool2d(kernel_size=2), + nn.Conv2d(6, 16, kernel_size=5, stride=1), nn.Tanh(), + nn.AvgPool2d(kernel_size=2), + nn.Conv2d(16, 120, kernel_size=5, stride=1), nn.Tanh()) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(120, 84), + nn.Tanh(), + nn.Linear(84, num_classes), + ) + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = self.classifier(x.squeeze()) + + return (x, ) diff --git a/mmpretrain/models/backbones/levit.py b/mmpretrain/models/backbones/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..5f7aa324e28b1725fb9e67110a26ea2d5c2831bd --- /dev/null +++ b/mmpretrain/models/backbones/levit.py @@ -0,0 +1,522 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, fuse_conv_bn +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +class HybridBackbone(BaseModule): + + def __init__( + self, + embed_dim, + kernel_size=3, + stride=2, + pad=1, + dilation=1, + groups=1, + act_cfg=dict(type='HSwish'), + conv_cfg=None, + norm_cfg=dict(type='BN'), + init_cfg=None, + ): + super(HybridBackbone, self).__init__(init_cfg=init_cfg) + + self.input_channels = [ + 3, embed_dim // 8, embed_dim // 4, embed_dim // 2 + ] + self.output_channels = [ + embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim + ] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.patch_embed = Sequential() + + for i in range(len(self.input_channels)): + conv_bn = ConvolutionBatchNorm( + self.input_channels[i], + self.output_channels[i], + kernel_size=kernel_size, + stride=stride, + pad=pad, + dilation=dilation, + groups=groups, + norm_cfg=norm_cfg, + ) + self.patch_embed.add_module('%d' % (2 * i), conv_bn) + if i < len(self.input_channels) - 1: + self.patch_embed.add_module('%d' % (i * 2 + 1), + build_activation_layer(act_cfg)) + + def forward(self, x): + x = self.patch_embed(x) + return x + + +class ConvolutionBatchNorm(BaseModule): + + def __init__( + self, + in_channel, + out_channel, + kernel_size=3, + stride=2, + pad=1, + dilation=1, + groups=1, + norm_cfg=dict(type='BN'), + ): + super(ConvolutionBatchNorm, self).__init__() + self.conv = nn.Conv2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=pad, + dilation=dilation, + groups=groups, + bias=False) + self.bn = build_norm_layer(norm_cfg, out_channel) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + @torch.no_grad() + def fuse(self): + return fuse_conv_bn(self).conv + + +class LinearBatchNorm(BaseModule): + + def __init__(self, in_feature, out_feature, norm_cfg=dict(type='BN1d')): + super(LinearBatchNorm, self).__init__() + self.linear = nn.Linear(in_feature, out_feature, bias=False) + self.bn = build_norm_layer(norm_cfg, out_feature) + + def forward(self, x): + x = self.linear(x) + x = self.bn(x.flatten(0, 1)).reshape_as(x) + return x + + @torch.no_grad() + def fuse(self): + w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5 + w = self.linear.weight * w[:, None] + b = self.bn.bias - self.bn.running_mean * self.bn.weight / \ + (self.bn.running_var + self.bn.eps) ** 0.5 + + factory_kwargs = { + 'device': self.linear.weight.device, + 'dtype': self.linear.weight.dtype + } + bias = nn.Parameter( + torch.empty(self.linear.out_features, **factory_kwargs)) + self.linear.register_parameter('bias', bias) + self.linear.weight.data.copy_(w) + self.linear.bias.data.copy_(b) + return self.linear + + +class Residual(BaseModule): + + def __init__(self, block, drop_path_rate=0.): + super(Residual, self).__init__() + self.block = block + if drop_path_rate > 0: + self.drop_path = DropPath(drop_path_rate) + else: + self.drop_path = nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.block(x)) + return x + + +class Attention(BaseModule): + + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + act_cfg=dict(type='HSwish'), + resolution=14, + ): + super(Attention, self).__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = LinearBatchNorm(dim, h) + self.proj = nn.Sequential( + build_activation_layer(act_cfg), LinearBatchNorm(self.dh, dim)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + """change the mode of model.""" + super(Attention, self).train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape # 2 196 128 + qkv = self.qkv(x) # 2 196 128 + q, k, v = qkv.view(B, N, self.num_heads, -1).split( + [self.key_dim, self.key_dim, self.d], + dim=3) # q 2 196 4 16 ; k 2 196 4 16; v 2 196 4 32 + q = q.permute(0, 2, 1, 3) # 2 4 196 16 + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ((q @ k.transpose(-2, -1)) * + self.scale # 2 4 196 16 * 2 4 16 196 -> 2 4 196 196 + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) # 2 4 196 196 -> 2 4 196 196 + x = (attn @ v).transpose(1, 2).reshape( + B, N, + self.dh) # 2 4 196 196 * 2 4 196 32 -> 2 4 196 32 -> 2 196 128 + x = self.proj(x) + return x + + +class MLP(nn.Sequential): + + def __init__(self, embed_dim, mlp_ratio, act_cfg=dict(type='HSwish')): + super(MLP, self).__init__() + h = embed_dim * mlp_ratio + self.linear1 = LinearBatchNorm(embed_dim, h) + self.activation = build_activation_layer(act_cfg) + self.linear2 = LinearBatchNorm(h, embed_dim) + + def forward(self, x): + x = self.linear1(x) + x = self.activation(x) + x = self.linear2(x) + return x + + +class Subsample(BaseModule): + + def __init__(self, stride, resolution): + super(Subsample, self).__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, _, C = x.shape + # B, N, C -> B, H, W, C + x = x.view(B, self.resolution, self.resolution, C) + x = x[:, ::self.stride, ::self.stride] + x = x.reshape(B, -1, C) # B, H', W', C -> B, N', C + return x + + +class AttentionSubsample(nn.Sequential): + + def __init__(self, + in_dim, + out_dim, + key_dim, + num_heads=8, + attn_ratio=2, + act_cfg=dict(type='HSwish'), + stride=2, + resolution=14): + super(AttentionSubsample, self).__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.sub_resolution = (resolution - 1) // stride + 1 + h = self.dh + nh_kd + self.kv = LinearBatchNorm(in_dim, h) + + self.q = nn.Sequential( + Subsample(stride, resolution), LinearBatchNorm(in_dim, nh_kd)) + self.proj = nn.Sequential( + build_activation_layer(act_cfg), LinearBatchNorm(self.dh, out_dim)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + sub_points = list( + itertools.product( + range(self.sub_resolution), range(self.sub_resolution))) + N = len(points) + N_sub = len(sub_points) + attention_offsets = {} + idxs = [] + for p1 in sub_points: + for p2 in points: + size = 1 + offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N_sub, N)) + + @torch.no_grad() + def train(self, mode=True): + super(AttentionSubsample, self).train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, + -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.sub_resolution**2, self.num_heads, + self.key_dim).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + \ + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +@MODELS.register_module() +class LeViT(BaseBackbone): + """LeViT backbone. + + A PyTorch implementation of `LeViT: A Vision Transformer in ConvNet's + Clothing for Faster Inference `_ + + Modified from the official implementation: + https://github.com/facebookresearch/LeViT + + Args: + arch (str | dict): LeViT architecture. + + If use string, choose from '128s', '128', '192', '256' and '384'. + If use dict, it should have below keys: + + - **embed_dims** (List[int]): The embed dimensions of each stage. + - **key_dims** (List[int]): The embed dimensions of the key in the + attention layers of each stage. + - **num_heads** (List[int]): The number of heads in each stage. + - **depths** (List[int]): The number of blocks in each stage. + + img_size (int): Input image size + patch_size (int | tuple): The patch size. Deault to 16 + attn_ratio (int): Ratio of hidden dimensions of the value in attention + layers. Defaults to 2. + mlp_ratio (int): Ratio of hidden dimensions in MLP layers. + Defaults to 2. + act_cfg (dict): The config of activation functions. + Defaults to ``dict(type='HSwish')``. + hybrid_backbone (callable): A callable object to build the patch embed + module. Defaults to use :class:`HybridBackbone`. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + deploy (bool): Whether to switch the model structure to + deployment mode. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + '128s': { + 'embed_dims': [128, 256, 384], + 'num_heads': [4, 6, 8], + 'depths': [2, 3, 4], + 'key_dims': [16, 16, 16], + }, + '128': { + 'embed_dims': [128, 256, 384], + 'num_heads': [4, 8, 12], + 'depths': [4, 4, 4], + 'key_dims': [16, 16, 16], + }, + '192': { + 'embed_dims': [192, 288, 384], + 'num_heads': [3, 5, 6], + 'depths': [4, 4, 4], + 'key_dims': [32, 32, 32], + }, + '256': { + 'embed_dims': [256, 384, 512], + 'num_heads': [4, 6, 8], + 'depths': [4, 4, 4], + 'key_dims': [32, 32, 32], + }, + '384': { + 'embed_dims': [384, 512, 768], + 'num_heads': [6, 9, 12], + 'depths': [4, 4, 4], + 'key_dims': [32, 32, 32], + }, + } + + def __init__(self, + arch, + img_size=224, + patch_size=16, + attn_ratio=2, + mlp_ratio=2, + act_cfg=dict(type='HSwish'), + hybrid_backbone=HybridBackbone, + out_indices=-1, + deploy=False, + drop_path_rate=0, + init_cfg=None): + super(LeViT, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch = self.arch_zoo[arch] + elif isinstance(arch, dict): + essential_keys = {'embed_dim', 'num_heads', 'depth', 'key_dim'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch = arch + else: + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + self.embed_dims = self.arch['embed_dims'] + self.num_heads = self.arch['num_heads'] + self.key_dims = self.arch['key_dims'] + self.depths = self.arch['depths'] + self.num_stages = len(self.embed_dims) + self.drop_path_rate = drop_path_rate + + self.patch_embed = hybrid_backbone(self.embed_dims[0]) + + self.resolutions = [] + resolution = img_size // patch_size + self.stages = ModuleList() + for i, (embed_dims, key_dims, depth, num_heads) in enumerate( + zip(self.embed_dims, self.key_dims, self.depths, + self.num_heads)): + blocks = [] + if i > 0: + downsample = AttentionSubsample( + in_dim=self.embed_dims[i - 1], + out_dim=embed_dims, + key_dim=key_dims, + num_heads=self.embed_dims[i - 1] // key_dims, + attn_ratio=4, + act_cfg=act_cfg, + stride=2, + resolution=resolution) + blocks.append(downsample) + resolution = downsample.sub_resolution + if mlp_ratio > 0: # mlp_ratio + blocks.append( + Residual( + MLP(embed_dims, mlp_ratio, act_cfg=act_cfg), + self.drop_path_rate)) + self.resolutions.append(resolution) + for _ in range(depth): + blocks.append( + Residual( + Attention( + embed_dims, + key_dims, + num_heads, + attn_ratio=attn_ratio, + act_cfg=act_cfg, + resolution=resolution, + ), self.drop_path_rate)) + if mlp_ratio > 0: + blocks.append( + Residual( + MLP(embed_dims, mlp_ratio, act_cfg=act_cfg), + self.drop_path_rate)) + + self.stages.append(Sequential(*blocks)) + + if isinstance(out_indices, int): + out_indices = [out_indices] + elif isinstance(out_indices, tuple): + out_indices = list(out_indices) + elif not isinstance(out_indices, list): + raise TypeError('"out_indices" must by a list, tuple or int, ' + f'get {type(out_indices)} instead.') + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert 0 <= out_indices[i] < self.num_stages, \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + self.deploy = False + if deploy: + self.switch_to_deploy() + + def switch_to_deploy(self): + if self.deploy: + return + fuse_parameters(self) + self.deploy = True + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) # B, C, H, W -> B, L, C + outs = [] + for i, stage in enumerate(self.stages): + x = stage(x) + B, _, C = x.shape + if i in self.out_indices: + out = x.reshape(B, self.resolutions[i], self.resolutions[i], C) + out = out.permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + +def fuse_parameters(module): + for child_name, child in module.named_children(): + if hasattr(child, 'fuse'): + setattr(module, child_name, child.fuse()) + else: + fuse_parameters(child) diff --git a/mmpretrain/models/backbones/mixmim.py b/mmpretrain/models/backbones/mixmim.py new file mode 100644 index 0000000000000000000000000000000000000000..2c67aa0c3a45c5c85adbacb94ae90dc170b2d0bb --- /dev/null +++ b/mmpretrain/models/backbones/mixmim.py @@ -0,0 +1,533 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging +from mmengine.model import BaseModule +from torch import nn +from torch.utils.checkpoint import checkpoint + +from mmpretrain.registry import MODELS +from ..utils import WindowMSA, to_2tuple +from .base_backbone import BaseBackbone +from .vision_transformer import TransformerEncoderLayer + + +class MixMIMWindowAttention(WindowMSA): + """MixMIM Window Attention. + + Compared with WindowMSA, we add some modifications + in ``forward`` to meet the requirement of MixMIM during + pretraining. + + Implements one windown attention in MixMIM. + Args: + embed_dims (int): The feature dimension. + window_size (list): The height and width of the window. + num_heads (int): The number of head in attention. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop_rate, + proj_drop=proj_drop_rate, + init_cfg=init_cfg) + + def forward(self, x, mask=None): + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + mask = mask.reshape(B_, 1, 1, N) + mask_new = mask * mask.transpose( + 2, 3) + (1 - mask) * (1 - mask).transpose(2, 3) + mask_new = 1 - mask_new + + if mask_new.dtype == torch.float16: + attn = attn - 65500 * mask_new + else: + attn = attn - 1e30 * mask_new + + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MixMIMBlock(TransformerEncoderLayer): + """MixMIM Block. Implements one block in MixMIM. + + Args: + embed_dims (int): The feature dimension. + input_resolution (tuple): Input resolution of this layer. + num_heads (int): The number of head in attention, + window_size (list): The height and width of the window. + mlp_ratio (int): The MLP ration in FFN. + num_fcs (int): The number of linear layers in a block. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. + Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4., + num_fcs=2, + qkv_bias=True, + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=int(mlp_ratio * embed_dims), + drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + if min(self.input_resolution) <= self.window_size: + self.window_size = min(self.input_resolution) + + self.attn = MixMIMWindowAttention( + embed_dims=embed_dims, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate) + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + @staticmethod + def window_reverse(windows, H, W, window_size): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + @staticmethod + def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + def forward(self, x, attn_mask=None): + H, W = self.input_resolution + B, L, C = x.shape + + shortcut = x + x = self.ln1(x) + x = x.view(B, H, W, C) + + # partition windows + x_windows = self.window_partition( + x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + if attn_mask is not None: + attn_mask = attn_mask.repeat(B, 1, 1) # B, N, 1 + attn_mask = attn_mask.view(B, H, W, 1) + attn_mask = self.window_partition(attn_mask, self.window_size) + attn_mask = attn_mask.view(-1, self.window_size * self.window_size, + 1) + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + x = self.window_reverse(attn_windows, H, W, + self.window_size) # B H' W' C + + x = x.view(B, H * W, C) + + x = shortcut + self.drop_path(x) + + x = self.ffn(self.norm2(x), identity=x) # ffn contains DropPath + + return x + + +class MixMIMLayer(BaseModule): + """Implements one MixMIM layer, which may contains several MixMIM blocks. + + Args: + embed_dims (int): The feature dimension. + input_resolution (tuple): Input resolution of this layer. + depth (int): The number of blocks in this layer. + num_heads (int): The number of head in attention, + window_size (list): The height and width of the window. + mlp_ratio (int): The MLP ration in FFN. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. + Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + downsample (class, optional): Downsample the output of blocks b + y patch merging.Defaults to None. + use_checkpoint (bool): Whether use the checkpoint to + reduce GPU memory cost. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + input_resolution: int, + depth: int, + num_heads: int, + window_size: int, + mlp_ratio=4., + qkv_bias=True, + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=[0.], + norm_cfg=dict(type='LN'), + downsample=None, + use_checkpoint=False, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList() + for i in range(depth): + self.blocks.append( + MixMIMBlock( + embed_dims=embed_dims, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate[i], + norm_cfg=norm_cfg)) + # patch merging layer + if downsample is not None: + self.downsample = downsample( + in_channels=embed_dims, + out_channels=2 * embed_dims, + norm_cfg=norm_cfg) + else: + self.downsample = None + + def forward(self, x, attn_mask=None): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask=attn_mask) + if self.downsample is not None: + x, _ = self.downsample(x, self.input_resolution) + return x + + def extra_repr(self) -> str: + return f'dim={self.embed_dims}, \ + input_resolution={self.input_resolution}, depth={self.depth}' + + +@MODELS.register_module() +class MixMIMTransformer(BaseBackbone): + """MixMIM backbone. + + A PyTorch implement of : ` MixMIM: Mixed and Masked Image + Modeling for Efficient Visual Representation Learning + `_ + + Args: + arch (str | dict): MixMIM architecture. If use string, + choose from 'base','large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + + Defaults to 'base'. + mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to mlp_ratio + the most common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + window_size (list): The height and width of the window. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + attn_drop_rate (float): attention drop rate. Defaults to 0. + use_checkpoint (bool): Whether use the checkpoint to + reduce GPU memory cost. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 352, + 'depths': [2, 2, 18, 2], + 'num_heads': [11, 22, 44, 88] + }), + } + + def __init__( + self, + arch='base', + mlp_ratio=4, + img_size=224, + patch_size=4, + in_channels=3, + window_size=[14, 14, 14, 7], + qkv_bias=True, + patch_cfg=dict(), + norm_cfg=dict(type='LN'), + drop_rate=0.0, + drop_path_rate=0.0, + attn_drop_rate=0.0, + use_checkpoint=False, + init_cfg: Optional[dict] = None, + ) -> None: + super(MixMIMTransformer, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + + self.encoder_stride = 32 + + self.num_layers = len(self.depths) + self.qkv_bias = qkv_bias + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.use_checkpoint = use_checkpoint + self.mlp_ratio = mlp_ratio + self.window_size = window_size + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + self.dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + self.layers.append( + MixMIMLayer( + embed_dims=int(self.embed_dims * 2**i_layer), + input_resolution=(self.patch_resolution[0] // (2**i_layer), + self.patch_resolution[1] // + (2**i_layer)), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + proj_drop_rate=self.drop_rate, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=self.dpr[sum(self.depths[:i_layer] + ):sum(self.depths[:i_layer + + 1])], + norm_cfg=norm_cfg, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=self.use_checkpoint)) + + self.num_features = int(self.embed_dims * 2**(self.num_layers - 1)) + self.drop_after_pos = nn.Dropout(p=self.drop_rate) + + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, self.embed_dims), + requires_grad=False) + + _, self.norm = build_norm_layer(norm_cfg, self.num_features) + + def forward(self, x: torch.Tensor): + x, _ = self.patch_embed(x) + + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for layer in self.layers: + x = layer(x, attn_mask=None) + + x = self.norm(x) + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + return (x, ) + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = sum(self.depths) + 2 + + if not param_name.startswith(prefix): + # For subsequent module like neck and head + if param_name.startswith('neck'): + return num_layers - 2, num_layers + else: + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + stem_layers = ('patch_embed', 'absolute_pos_embed', 'pos_embed') + if any(stem in param_name for stem in stem_layers): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + block_id = param_name.split('.')[3] + + if block_id in ('downsample', 'reduction', 'norm'): + layer_depth = sum(self.depths[:layer_id + 1]) + else: + layer_depth = sum(self.depths[:layer_id]) + int(block_id) + 1 + else: + layer_depth = num_layers - 2 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/mlp_mixer.py b/mmpretrain/models/backbones/mlp_mixer.py new file mode 100644 index 0000000000000000000000000000000000000000..26fb8ce0186c2451a5698c413ebf2bc24f33b6ec --- /dev/null +++ b/mmpretrain/models/backbones/mlp_mixer.py @@ -0,0 +1,263 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from ..utils import to_2tuple +from .base_backbone import BaseBackbone + + +class MixerBlock(BaseModule): + """Mlp-Mixer basic block. + + Basic module of `MLP-Mixer: An all-MLP Architecture for Vision + `_ + + Args: + num_tokens (int): The number of patched tokens + embed_dims (int): The feature dimension + tokens_mlp_dims (int): The hidden dimension for tokens FFNs + channels_mlp_dims (int): The hidden dimension for channels FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + num_tokens, + embed_dims, + tokens_mlp_dims, + channels_mlp_dims, + drop_rate=0., + drop_path_rate=0., + num_fcs=2, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(MixerBlock, self).__init__(init_cfg=init_cfg) + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + self.token_mix = FFN( + embed_dims=num_tokens, + feedforward_channels=tokens_mlp_dims, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + self.channel_mix = FFN( + embed_dims=embed_dims, + feedforward_channels=channels_mlp_dims, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def init_weights(self): + super(MixerBlock, self).init_weights() + for m in self.token_mix.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + for m in self.channel_mix.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + out = self.norm1(x).transpose(1, 2) + x = x + self.token_mix(out).transpose(1, 2) + x = self.channel_mix(self.norm2(x), identity=x) + return x + + +@MODELS.register_module() +class MlpMixer(BaseBackbone): + """Mlp-Mixer backbone. + + Pytorch implementation of `MLP-Mixer: An all-MLP Architecture for Vision + `_ + + Args: + arch (str | dict): MLP Mixer architecture. If use string, choose from + 'small', 'base' and 'large'. If use dict, it should have below + keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of MLP blocks. + - **tokens_mlp_dims** (int): The hidden dimensions for tokens FFNs. + - **channels_mlp_dims** (int): The The hidden dimensions for + channels FFNs. + + Defaults to 'base'. + img_size (int | tuple): The input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + out_indices (Sequence | int): Output from which layer. + Defaults to -1, means the last layer. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + act_cfg (dict): The activation config for FFNs. Default GELU. + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each mixer block layer. + Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 512, + 'num_layers': 8, + 'tokens_mlp_dims': 256, + 'channels_mlp_dims': 2048, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'tokens_mlp_dims': 384, + 'channels_mlp_dims': 3072, + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'tokens_mlp_dims': 512, + 'channels_mlp_dims': 4096, + }), + } + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super(MlpMixer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'tokens_mlp_dims', + 'channels_mlp_dims' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.tokens_mlp_dims = self.arch_settings['tokens_mlp_dims'] + self.channels_mlp_dims = self.arch_settings['channels_mlp_dims'] + + self.img_size = to_2tuple(img_size) + + _patch_cfg = dict( + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must be a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + else: + assert index >= self.num_layers, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + num_tokens=num_patches, + embed_dims=self.embed_dims, + tokens_mlp_dims=self.tokens_mlp_dims, + channels_mlp_dims=self.channels_mlp_dims, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(MixerBlock(**_layer_cfg)) + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def forward(self, x): + assert x.shape[2:] == self.img_size, \ + "The MLP-Mixer doesn't support dynamic input shape. " \ + f'Please input images with shape {self.img_size}' + x, _ = self.patch_embed(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1: + x = self.norm1(x) + + if i in self.out_indices: + out = x.transpose(1, 2) + outs.append(out) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/mobilenet_v2.py b/mmpretrain/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..bca1418a13c4ed81c4666e7f53b0417c36b2e99b --- /dev/null +++ b/mmpretrain/models/backbones/mobilenet_v2.py @@ -0,0 +1,264 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import make_divisible +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class InvertedResidual(BaseModule): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=1, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class MobileNetV2(BaseBackbone): + """MobileNetV2 backbone. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks, stride. + arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], + [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], + [6, 320, 1, 1]] + + def __init__(self, + widen_factor=1., + out_indices=(7, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(MobileNetV2, self).__init__(init_cfg) + self.widen_factor = widen_factor + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 8): + raise ValueError('the item in out_indices must in ' + f'range(0, 8). But received {index}') + + if frozen_stages not in range(-1, 8): + raise ValueError('frozen_stages must be in range(-1, 8). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks, stride = layer_cfg + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + if widen_factor > 1.0: + self.out_channel = int(1280 * widen_factor) + else: + self.out_channel = 1280 + + layer = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channel, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.add_module('conv2', layer) + self.layers.append('conv2') + + def make_layer(self, out_channels, num_blocks, stride, expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. Default: 6. + """ + layers = [] + for i in range(num_blocks): + if i >= 1: + stride = 1 + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride, + expand_ratio=expand_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/mobilenet_v3.py b/mmpretrain/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..577dba94040dec5ecda9388819b8b5205f307dce --- /dev/null +++ b/mmpretrain/models/backbones/mobilenet_v3.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import InvertedResidual +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class MobileNetV3(BaseBackbone): + """MobileNetV3 backbone. + + Args: + arch (str): Architecture of mobilnetv3, from {small, large}. + Default: small. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (None or Sequence[int]): Output from which stages. + Default: None, which means output tensors from final stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 2], + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'small_075': [[3, 16, 16, True, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 2], + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 32, True, 'HSwish', 2], + [5, 192, 32, True, 'HSwish', 1], + [5, 192, 32, True, 'HSwish', 1], + [5, 96, 40, True, 'HSwish', 1], + [5, 120, 40, True, 'HSwish', 1], + [5, 240, 72, True, 'HSwish', 2], + [5, 432, 72, True, 'HSwish', 1], + [5, 432, 72, True, 'HSwish', 1]], + 'small_050': [[3, 16, 8, True, 'ReLU', 2], + [3, 40, 16, False, 'ReLU', 2], + [3, 56, 16, False, 'ReLU', 1], + [5, 64, 24, True, 'HSwish', 2], + [5, 144, 24, True, 'HSwish', 1], + [5, 144, 24, True, 'HSwish', 1], + [5, 72, 24, True, 'HSwish', 1], + [5, 72, 24, True, 'HSwish', 1], + [5, 144, 48, True, 'HSwish', 2], + [5, 288, 48, True, 'HSwish', 1], + [5, 288, 48, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], + [3, 64, 24, False, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN', eps=0.001, momentum=0.01), + out_indices=None, + frozen_stages=-1, + norm_eval=False, + with_cp=False, + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + nonlinearity='leaky_relu'), + dict(type='Normal', layer=['Linear'], std=0.01), + dict(type='Constant', layer=['BatchNorm2d'], val=1) + ]): + super(MobileNetV3, self).__init__(init_cfg) + assert arch in self.arch_settings + if out_indices is None: + out_indices = (12, ) if 'small' in arch else (16, ) + for order, index in enumerate(out_indices): + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch]) + 2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch]) + 2}). ' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layers = self._make_layer() + self.feat_dim = self.arch_settings[arch][-1][1] + + def _make_layer(self): + layers = [] + layer_setting = self.arch_settings[self.arch] + in_channels = 16 + + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict( + type='HSigmoid', + bias=3, + divisor=6, + min_value=0, + max_value=1))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = 'layer{}'.format(i + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # Build the last layer before pooling + # TODO: No dilation + layer = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = 'layer{}'.format(len(layer_setting) + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + return layers + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV3, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/mobileone.py b/mmpretrain/models/backbones/mobileone.py new file mode 100644 index 0000000000000000000000000000000000000000..1111441af82d43a49d15ecbb5dc0778fc9f87596 --- /dev/null +++ b/mmpretrain/models/backbones/mobileone.py @@ -0,0 +1,515 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from official impl https://github.com/apple/ml-mobileone/blob/main/mobileone.py # noqa: E501 +from typing import Optional, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .base_backbone import BaseBackbone + + +class MobileOneBlock(BaseModule): + """MobileOne block for MobileOne backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + kernel_size (int): The kernel size of the convs in the block. If the + kernel size is large than 1, there will be a ``branch_scale`` in + the block. + num_convs (int): Number of the convolution branches in the block. + stride (int): Stride of convolution layers. Defaults to 1. + padding (int): Padding of the convolution layers. Defaults to 1. + dilation (int): Dilation of the convolution layers. Defaults to 1. + groups (int): Groups of the convolution layers. Defaults to 1. + se_cfg (None or dict): The configuration of the se module. + Defaults to None. + norm_cfg (dict): Configuration to construct and config norm layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + deploy (bool): Whether the model structure is in the deployment mode. + Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int, + num_convs: int, + stride: int = 1, + padding: int = 1, + dilation: int = 1, + groups: int = 1, + se_cfg: Optional[dict] = None, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = dict(type='BN'), + act_cfg: Optional[dict] = dict(type='ReLU'), + deploy: bool = False, + init_cfg: Optional[dict] = None): + super(MobileOneBlock, self).__init__(init_cfg) + + assert se_cfg is None or isinstance(se_cfg, dict) + if se_cfg is not None: + self.se = SELayer(channels=out_channels, **se_cfg) + else: + self.se = nn.Identity() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_conv_branches = num_convs + self.stride = stride + self.padding = padding + self.se_cfg = se_cfg + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.deploy = deploy + self.groups = groups + self.dilation = dilation + + if deploy: + self.branch_reparam = build_conv_layer( + conv_cfg, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + groups=self.groups, + stride=stride, + padding=padding, + dilation=dilation, + bias=True) + else: + # judge if input shape and output shape are the same. + # If true, add a normalized identity shortcut. + if out_channels == in_channels and stride == 1: + self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1] + else: + self.branch_norm = None + + self.branch_scale = None + if kernel_size > 1: + self.branch_scale = self.create_conv_bn(kernel_size=1) + + self.branch_conv_list = ModuleList() + for _ in range(num_convs): + self.branch_conv_list.append( + self.create_conv_bn( + kernel_size=kernel_size, + padding=padding, + dilation=dilation)) + + self.act = build_activation_layer(act_cfg) + + def create_conv_bn(self, kernel_size, dilation=1, padding=0): + """cearte a (conv + bn) Sequential layer.""" + conv_bn = Sequential() + conv_bn.add_module( + 'conv', + build_conv_layer( + self.conv_cfg, + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + groups=self.groups, + stride=self.stride, + dilation=dilation, + padding=padding, + bias=False)) + conv_bn.add_module( + 'norm', + build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1]) + + return conv_bn + + def forward(self, x): + + def _inner_forward(inputs): + if self.deploy: + return self.branch_reparam(inputs) + + inner_out = 0 + if self.branch_norm is not None: + inner_out = self.branch_norm(inputs) + + if self.branch_scale is not None: + inner_out += self.branch_scale(inputs) + + for branch_conv in self.branch_conv_list: + inner_out += branch_conv(inputs) + + return inner_out + + return self.act(self.se(_inner_forward(x))) + + def switch_to_deploy(self): + """Switch the model structure from training mode to deployment mode.""" + if self.deploy: + return + assert self.norm_cfg['type'] == 'BN', \ + "Switch is not allowed when norm_cfg['type'] != 'BN'." + + reparam_weight, reparam_bias = self.reparameterize() + self.branch_reparam = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + bias=True) + self.branch_reparam.weight.data = reparam_weight + self.branch_reparam.bias.data = reparam_bias + + for param in self.parameters(): + param.detach_() + delattr(self, 'branch_conv_list') + if hasattr(self, 'branch_scale'): + delattr(self, 'branch_scale') + delattr(self, 'branch_norm') + + self.deploy = True + + def reparameterize(self): + """Fuse all the parameters of all branches. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all + branches. the first element is the weights and the second is + the bias. + """ + weight_conv, bias_conv = 0, 0 + for branch_conv in self.branch_conv_list: + weight, bias = self._fuse_conv_bn(branch_conv) + weight_conv += weight + bias_conv += bias + + weight_scale, bias_scale = 0, 0 + if self.branch_scale is not None: + weight_scale, bias_scale = self._fuse_conv_bn(self.branch_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + weight_scale = F.pad(weight_scale, [pad, pad, pad, pad]) + + weight_norm, bias_norm = 0, 0 + if self.branch_norm: + tmp_conv_bn = self._norm_to_conv(self.branch_norm) + weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn) + + return (weight_conv + weight_scale + weight_norm, + bias_conv + bias_scale + bias_norm) + + def _fuse_conv_bn(self, branch): + """Fuse the parameters in a branch with a conv and bn. + + Args: + branch (mmcv.runner.Sequential): A branch with conv and bn. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The parameters obtained after + fusing the parameters of conv and bn in one branch. + The first element is the weight and the second is the bias. + """ + if branch is None: + return 0, 0 + kernel = branch.conv.weight + running_mean = branch.norm.running_mean + running_var = branch.norm.running_var + gamma = branch.norm.weight + beta = branch.norm.bias + eps = branch.norm.eps + + std = (running_var + eps).sqrt() + fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * kernel + fused_bias = beta - running_mean * gamma / std + + return fused_weight, fused_bias + + def _norm_to_conv(self, branch_nrom): + """Convert a norm layer to a conv-bn sequence towards + ``self.kernel_size``. + + Args: + branch (nn.BatchNorm2d): A branch only with bn in the block. + + Returns: + (mmcv.runner.Sequential): a sequential with conv and bn. + """ + input_dim = self.in_channels // self.groups + conv_weight = torch.zeros( + (self.in_channels, input_dim, self.kernel_size, self.kernel_size), + dtype=branch_nrom.weight.dtype) + + for i in range(self.in_channels): + conv_weight[i, i % input_dim, self.kernel_size // 2, + self.kernel_size // 2] = 1 + conv_weight = conv_weight.to(branch_nrom.weight.device) + + tmp_conv = self.create_conv_bn(kernel_size=self.kernel_size) + tmp_conv.conv.weight.data = conv_weight + tmp_conv.norm = branch_nrom + return tmp_conv + + +@MODELS.register_module() +class MobileOne(BaseBackbone): + """MobileOne backbone. + + A PyTorch impl of : `An Improved One millisecond Mobile Backbone + `_ + + Args: + arch (str | dict): MobileOne architecture. If use string, choose + from 's0', 's1', 's2', 's3' and 's4'. If use dict, it should + have below keys: + + - num_blocks (Sequence[int]): Number of blocks in each stage. + - width_factor (Sequence[float]): Width factor in each stage. + - num_conv_branches (Sequence[int]): Number of conv branches + in each stage. + - num_se_blocks (Sequence[int]): Number of SE layers in each + stage, all the SE layers are placed in the subsequent order + in each stage. + + Defaults to 's0'. + in_channels (int): Number of input image channels. Default: 3. + out_indices (Sequence[int] | int): Output from which stages. + Defaults to ``(3, )``. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Defaults to -1. + conv_cfg (dict | None): The config dict for conv layers. + Defaults to None. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + deploy (bool): Whether to switch the model structure to deployment + mode. Defaults to False. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> from mmpretrain.models import MobileOne + >>> import torch + >>> x = torch.rand(1, 3, 224, 224) + >>> model = MobileOne("s0", out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> outputs = model(x) + >>> for out in outputs: + ... print(tuple(out.shape)) + (1, 48, 56, 56) + (1, 128, 28, 28) + (1, 256, 14, 14) + (1, 1024, 7, 7) + """ + + arch_zoo = { + 's0': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[0.75, 1.0, 1.0, 2.0], + num_conv_branches=[4, 4, 4, 4], + num_se_blocks=[0, 0, 0, 0]), + 's1': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[1.5, 1.5, 2.0, 2.5], + num_conv_branches=[1, 1, 1, 1], + num_se_blocks=[0, 0, 0, 0]), + 's2': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[1.5, 2.0, 2.5, 4.0], + num_conv_branches=[1, 1, 1, 1], + num_se_blocks=[0, 0, 0, 0]), + 's3': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[2.0, 2.5, 3.0, 4.0], + num_conv_branches=[1, 1, 1, 1], + num_se_blocks=[0, 0, 0, 0]), + 's4': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[3.0, 3.5, 3.5, 4.0], + num_conv_branches=[1, 1, 1, 1], + num_se_blocks=[0, 0, 5, 1]) + } + + def __init__(self, + arch, + in_channels=3, + out_indices=(3, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + se_cfg=dict(ratio=16), + deploy=False, + norm_eval=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm']) + ]): + super(MobileOne, self).__init__(init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_zoo, f'"arch": "{arch}"' \ + f' is not one of the {list(self.arch_zoo.keys())}' + arch = self.arch_zoo[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + self.arch = arch + for k, value in self.arch.items(): + assert isinstance(value, list) and len(value) == 4, \ + f'the value of {k} in arch must be list with 4 items.' + + self.in_channels = in_channels + self.deploy = deploy + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.se_cfg = se_cfg + self.act_cfg = act_cfg + + base_channels = [64, 128, 256, 512] + channels = min(64, + int(base_channels[0] * self.arch['width_factor'][0])) + self.stage0 = MobileOneBlock( + self.in_channels, + channels, + stride=2, + kernel_size=3, + num_convs=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + deploy=deploy) + + self.in_planes = channels + self.stages = [] + for i, num_blocks in enumerate(self.arch['num_blocks']): + planes = int(base_channels[i] * self.arch['width_factor'][i]) + + stage = self._make_stage(planes, num_blocks, + arch['num_se_blocks'][i], + arch['num_conv_branches'][i]) + + stage_name = f'stage{i + 1}' + self.add_module(stage_name, stage) + self.stages.append(stage_name) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = len(self.stages) + index + assert 0 <= out_indices[i] <= len(self.stages), \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + def _make_stage(self, planes, num_blocks, num_se, num_conv_branches): + strides = [2] + [1] * (num_blocks - 1) + if num_se > num_blocks: + raise ValueError('Number of SE blocks cannot ' + 'exceed number of layers.') + blocks = [] + for i in range(num_blocks): + use_se = False + if i >= (num_blocks - num_se): + use_se = True + + blocks.append( + # Depthwise conv + MobileOneBlock( + in_channels=self.in_planes, + out_channels=self.in_planes, + kernel_size=3, + num_convs=num_conv_branches, + stride=strides[i], + padding=1, + groups=self.in_planes, + se_cfg=self.se_cfg if use_se else None, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + deploy=self.deploy)) + + blocks.append( + # Pointwise conv + MobileOneBlock( + in_channels=self.in_planes, + out_channels=planes, + kernel_size=1, + num_convs=num_conv_branches, + stride=1, + padding=0, + se_cfg=self.se_cfg if use_se else None, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + deploy=self.deploy)) + + self.in_planes = planes + + return Sequential(*blocks) + + def forward(self, x): + x = self.stage0(x) + outs = [] + for i, stage_name in enumerate(self.stages): + stage = getattr(self, stage_name) + x = stage(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stage0.eval() + for param in self.stage0.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + stage = getattr(self, f'stage{i+1}') + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + """switch the mobile to train mode or not.""" + super(MobileOne, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def switch_to_deploy(self): + """switch the model to deploy mode, which has smaller amount of + parameters and calculations.""" + for m in self.modules(): + if isinstance(m, MobileOneBlock): + m.switch_to_deploy() + self.deploy = True diff --git a/mmpretrain/models/backbones/mobilevit.py b/mmpretrain/models/backbones/mobilevit.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4043fe46049a4d1bddecc6b7b3768236318e82 --- /dev/null +++ b/mmpretrain/models/backbones/mobilevit.py @@ -0,0 +1,431 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Callable, Optional, Sequence + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_norm_layer +from torch import nn + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone +from .mobilenet_v2 import InvertedResidual +from .vision_transformer import TransformerEncoderLayer + + +class MobileVitBlock(nn.Module): + """MobileViT block. + + According to the paper, the MobileViT block has a local representation. + a transformer-as-convolution layer which consists of a global + representation with unfolding and folding, and a final fusion layer. + + Args: + in_channels (int): Number of input image channels. + transformer_dim (int): Number of transformer channels. + ffn_dim (int): Number of ffn channels in transformer block. + out_channels (int): Number of channels in output. + conv_ksize (int): Conv kernel size in local representation + and fusion. Defaults to 3. + conv_cfg (dict, optional): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer. + Defaults to dict(type='Swish'). + num_transformer_blocks (int): Number of transformer blocks in + a MobileViT block. Defaults to 2. + patch_size (int): Patch size for unfolding and folding. + Defaults to 2. + num_heads (int): Number of heads in global representation. + Defaults to 4. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + no_fusion (bool): Whether to remove the fusion layer. + Defaults to False. + transformer_norm_cfg (dict, optional): Config dict for normalization + layer in transformer. Defaults to dict(type='LN'). + """ + + def __init__( + self, + in_channels: int, + transformer_dim: int, + ffn_dim: int, + out_channels: int, + conv_ksize: int = 3, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = dict(type='BN'), + act_cfg: Optional[dict] = dict(type='Swish'), + num_transformer_blocks: int = 2, + patch_size: int = 2, + num_heads: int = 4, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + no_fusion: bool = False, + transformer_norm_cfg: Callable = dict(type='LN'), + ): + super(MobileVitBlock, self).__init__() + + self.local_rep = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=conv_ksize, + padding=int((conv_ksize - 1) / 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=in_channels, + out_channels=transformer_dim, + kernel_size=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=None, + act_cfg=None), + ) + + global_rep = [ + TransformerEncoderLayer( + embed_dims=transformer_dim, + num_heads=num_heads, + feedforward_channels=ffn_dim, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + qkv_bias=True, + act_cfg=dict(type='Swish'), + norm_cfg=transformer_norm_cfg) + for _ in range(num_transformer_blocks) + ] + global_rep.append( + build_norm_layer(transformer_norm_cfg, transformer_dim)[1]) + self.global_rep = nn.Sequential(*global_rep) + + self.conv_proj = ConvModule( + in_channels=transformer_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if no_fusion: + self.conv_fusion = None + else: + self.conv_fusion = ConvModule( + in_channels=in_channels + out_channels, + out_channels=out_channels, + kernel_size=conv_ksize, + padding=int((conv_ksize - 1) / 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.patch_size = (patch_size, patch_size) + self.patch_area = self.patch_size[0] * self.patch_size[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + + # Local representation + x = self.local_rep(x) + + # Unfold (feature map -> patches) + patch_h, patch_w = self.patch_size + B, C, H, W = x.shape + new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil( + W / patch_w) * patch_w + num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w # noqa + num_patches = num_patch_h * num_patch_w # N + interpolate = False + if new_h != H or new_w != W: + # Note: Padding can be done, but then it needs to be handled in attention function. # noqa + x = F.interpolate( + x, size=(new_h, new_w), mode='bilinear', align_corners=False) + interpolate = True + + # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w] + x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, + patch_w).transpose(1, 2) + # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w # noqa + x = x.reshape(B, C, num_patches, + self.patch_area).transpose(1, 3).reshape( + B * self.patch_area, num_patches, -1) + + # Global representations + x = self.global_rep(x) + + # Fold (patch -> feature map) + # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w] + x = x.contiguous().view(B, self.patch_area, num_patches, -1) + x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, + patch_h, patch_w) + # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] # noqa + x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, + num_patch_w * patch_w) + if interpolate: + x = F.interpolate( + x, size=(H, W), mode='bilinear', align_corners=False) + + x = self.conv_proj(x) + if self.conv_fusion is not None: + x = self.conv_fusion(torch.cat((shortcut, x), dim=1)) + return x + + +@MODELS.register_module() +class MobileViT(BaseBackbone): + """MobileViT backbone. + + A PyTorch implementation of : `MobileViT: Light-weight, General-purpose, + and Mobile-friendly Vision Transformer `_ + + Modified from the `official repo + `_ + and `timm + `_. + + Args: + arch (str | List[list]): Architecture of MobileViT. + + - If a string, choose from "small", "x_small" and "xx_small". + + - If a list, every item should be also a list, and the first item + of the sub-list can be chosen from "moblienetv2" and "mobilevit", + which indicates the type of this layer sequence. If "mobilenetv2", + the other items are the arguments of :attr:`~MobileViT.make_mobilenetv2_layer` + (except ``in_channels``) and if "mobilevit", the other items are + the arguments of :attr:`~MobileViT.make_mobilevit_layer` + (except ``in_channels``). + + Defaults to "small". + in_channels (int): Number of input image channels. Defaults to 3. + stem_channels (int): Channels of stem layer. Defaults to 16. + last_exp_factor (int): Channels expand factor of last layer. + Defaults to 4. + out_indices (Sequence[int]): Output from which stages. + Defaults to (4, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer. + Defaults to dict(type='Swish'). + init_cfg (dict, optional): Initialization config dict. + """ # noqa + + # Parameters to build layers. The first param is the type of layer. + # For `mobilenetv2` layer, the rest params from left to right are: + # out channels, stride, num of blocks, expand_ratio. + # For `mobilevit` layer, the rest params from left to right are: + # out channels, stride, transformer_channels, ffn channels, + # num of transformer blocks, expand_ratio. + arch_settings = { + 'small': [ + ['mobilenetv2', 32, 1, 1, 4], + ['mobilenetv2', 64, 2, 3, 4], + ['mobilevit', 96, 2, 144, 288, 2, 4], + ['mobilevit', 128, 2, 192, 384, 4, 4], + ['mobilevit', 160, 2, 240, 480, 3, 4], + ], + 'x_small': [ + ['mobilenetv2', 32, 1, 1, 4], + ['mobilenetv2', 48, 2, 3, 4], + ['mobilevit', 64, 2, 96, 192, 2, 4], + ['mobilevit', 80, 2, 120, 240, 4, 4], + ['mobilevit', 96, 2, 144, 288, 3, 4], + ], + 'xx_small': [ + ['mobilenetv2', 16, 1, 1, 2], + ['mobilenetv2', 24, 2, 3, 2], + ['mobilevit', 48, 2, 64, 128, 2, 2], + ['mobilevit', 64, 2, 80, 160, 4, 2], + ['mobilevit', 80, 2, 96, 192, 3, 2], + ] + } + + def __init__(self, + arch='small', + in_channels=3, + stem_channels=16, + last_exp_factor=4, + out_indices=(4, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='Swish'), + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(MobileViT, self).__init__(init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a list.' + arch = self.arch_settings[arch] + + self.arch = arch + self.num_stages = len(arch) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + if frozen_stages not in range(-1, self.num_stages): + raise ValueError('frozen_stages must be in range(-1, ' + f'{self.num_stages}). ' + f'But received {frozen_stages}') + self.frozen_stages = frozen_stages + + _make_layer_func = { + 'mobilenetv2': self.make_mobilenetv2_layer, + 'mobilevit': self.make_mobilevit_layer, + } + + self.stem = ConvModule( + in_channels=in_channels, + out_channels=stem_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + in_channels = stem_channels + layers = [] + for i, layer_settings in enumerate(arch): + layer_type, settings = layer_settings[0], layer_settings[1:] + layer, out_channels = _make_layer_func[layer_type](in_channels, + *settings) + layers.append(layer) + in_channels = out_channels + self.layers = nn.Sequential(*layers) + + self.conv_1x1_exp = ConvModule( + in_channels=in_channels, + out_channels=last_exp_factor * in_channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + @staticmethod + def make_mobilevit_layer(in_channels, + out_channels, + stride, + transformer_dim, + ffn_dim, + num_transformer_blocks, + expand_ratio=4): + """Build mobilevit layer, which consists of one InvertedResidual and + one MobileVitBlock. + + Args: + in_channels (int): The input channels. + out_channels (int): The output channels. + stride (int): The stride of the first 3x3 convolution in the + ``InvertedResidual`` layers. + transformer_dim (int): The channels of the transformer layers. + ffn_dim (int): The mid-channels of the feedforward network in + transformer layers. + num_transformer_blocks (int): The number of transformer blocks. + expand_ratio (int): adjusts number of channels of the hidden layer + in ``InvertedResidual`` by this amount. Defaults to 4. + """ + layer = [] + layer.append( + InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + act_cfg=dict(type='Swish'), + )) + layer.append( + MobileVitBlock( + in_channels=out_channels, + transformer_dim=transformer_dim, + ffn_dim=ffn_dim, + out_channels=out_channels, + num_transformer_blocks=num_transformer_blocks, + )) + return nn.Sequential(*layer), out_channels + + @staticmethod + def make_mobilenetv2_layer(in_channels, + out_channels, + stride, + num_blocks, + expand_ratio=4): + """Build mobilenetv2 layer, which consists of several InvertedResidual + layers. + + Args: + in_channels (int): The input channels. + out_channels (int): The output channels. + stride (int): The stride of the first 3x3 convolution in the + ``InvertedResidual`` layers. + num_blocks (int): The number of ``InvertedResidual`` blocks. + expand_ratio (int): adjusts number of channels of the hidden layer + in ``InvertedResidual`` by this amount. Defaults to 4. + """ + layer = [] + for i in range(num_blocks): + stride = stride if i == 0 else 1 + + layer.append( + InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + act_cfg=dict(type='Swish'), + )) + in_channels = out_channels + return nn.Sequential(*layer), out_channels + + def _freeze_stages(self): + for i in range(0, self.frozen_stages): + layer = self.layers[i] + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileViT, self).train(mode) + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + x = self.conv_1x1_exp(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/mvit.py b/mmpretrain/models/backbones/mvit.py new file mode 100644 index 0000000000000000000000000000000000000000..68aee97ddf3077ca58e488f38e9d9422b171d691 --- /dev/null +++ b/mmpretrain/models/backbones/mvit.py @@ -0,0 +1,700 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import to_2tuple + +from ..builder import BACKBONES +from ..utils import resize_pos_embed +from .base_backbone import BaseBackbone + + +def resize_decomposed_rel_pos(rel_pos, q_size, k_size): + """Get relative positional embeddings according to the relative positions + of query and key sizes. + + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + resized = F.interpolate( + # (L, C) -> (1, C, L) + rel_pos.transpose(0, 1).unsqueeze(0), + size=max_rel_dist, + mode='linear', + ) + # (1, C, L) -> (L, C) + resized = resized.squeeze(0).transpose(0, 1) + else: + resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_h_ratio = max(k_size / q_size, 1.0) + k_h_ratio = max(q_size / k_size, 1.0) + q_coords = torch.arange(q_size)[:, None] * q_h_ratio + k_coords = torch.arange(k_size)[None, :] * k_h_ratio + relative_coords = (q_coords - k_coords) + (k_size - 1) * k_h_ratio + + return resized[relative_coords.long()] + + +def add_decomposed_rel_pos(attn, + q, + q_shape, + k_shape, + rel_pos_h, + rel_pos_w, + has_cls_token=False): + """Spatial Relative Positional Embeddings.""" + sp_idx = 1 if has_cls_token else 0 + B, num_heads, _, C = q.shape + q_h, q_w = q_shape + k_h, k_w = k_shape + + Rh = resize_decomposed_rel_pos(rel_pos_h, q_h, k_h) + Rw = resize_decomposed_rel_pos(rel_pos_w, q_w, k_w) + + r_q = q[:, :, sp_idx:].reshape(B, num_heads, q_h, q_w, C) + rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh) + rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw) + rel_pos_embed = rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :] + + attn_map = attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) + attn_map += rel_pos_embed + attn[:, :, sp_idx:, sp_idx:] = attn_map.view(B, -1, q_h * q_w, k_h * k_w) + + return attn + + +class MLP(BaseModule): + """Two-layer multilayer perceptron. + + Comparing with :class:`mmcv.cnn.bricks.transformer.FFN`, this class allows + different input and output channel numbers. + + Args: + in_channels (int): The number of input channels. + hidden_channels (int, optional): The number of hidden layer channels. + If None, same as the ``in_channels``. Defaults to None. + out_channels (int, optional): The number of output channels. If None, + same as the ``in_channels``. Defaults to None. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. + """ + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + act_cfg=dict(type='GELU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Linear(hidden_channels, out_channels) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +def attention_pool(x: torch.Tensor, + pool: nn.Module, + in_size: tuple, + norm: Optional[nn.Module] = None): + """Pooling the feature tokens. + + Args: + x (torch.Tensor): The input tensor, should be with shape + ``(B, num_heads, L, C)`` or ``(B, L, C)``. + pool (nn.Module): The pooling module. + in_size (Tuple[int]): The shape of the input feature map. + norm (nn.Module, optional): The normalization module. + Defaults to None. + """ + ndim = x.ndim + if ndim == 4: + B, num_heads, L, C = x.shape + elif ndim == 3: + num_heads = 1 + B, L, C = x.shape + else: + raise RuntimeError(f'Unsupported input dimension {x.shape}') + + H, W = in_size + assert L == H * W + + # (B, num_heads, H*W, C) -> (B*num_heads, C, H, W) + x = x.reshape(B * num_heads, H, W, C).permute(0, 3, 1, 2).contiguous() + x = pool(x) + out_size = x.shape[-2:] + + # (B*num_heads, C, H', W') -> (B, num_heads, H'*W', C) + x = x.reshape(B, num_heads, C, -1).transpose(2, 3) + + if norm is not None: + x = norm(x) + + if ndim == 3: + x = x.squeeze(1) + + return x, out_size + + +class MultiScaleAttention(BaseModule): + """Multiscale Multi-head Attention block. + + Args: + in_dims (int): Number of input channels. + out_dims (int): Number of output channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key and + value. Defaults to True. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='LN')``. + pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + stride_q (int): stride size for q pooling layer. Defaults to 1. + stride_kv (int): stride size for kv pooling layer. Defaults to 1. + rel_pos_spatial (bool): Whether to enable the spatial relative + position embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + input_size (Tuple[int], optional): The input resolution, necessary + if enable the ``rel_pos_spatial``. Defaults to None. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. + """ + + def __init__(self, + in_dims, + out_dims, + num_heads, + qkv_bias=True, + norm_cfg=dict(type='LN'), + pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + rel_pos_spatial=False, + residual_pooling=True, + input_size=None, + rel_pos_zero_init=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.in_dims = in_dims + self.out_dims = out_dims + + head_dim = out_dims // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(in_dims, out_dims * 3, bias=qkv_bias) + self.proj = nn.Linear(out_dims, out_dims) + + # qkv pooling + pool_padding = [k // 2 for k in pool_kernel] + pool_dims = out_dims // num_heads + + def build_pooling(stride): + pool = nn.Conv2d( + pool_dims, + pool_dims, + pool_kernel, + stride=stride, + padding=pool_padding, + groups=pool_dims, + bias=False, + ) + norm = build_norm_layer(norm_cfg, pool_dims)[1] + return pool, norm + + self.pool_q, self.norm_q = build_pooling(stride_q) + self.pool_k, self.norm_k = build_pooling(stride_kv) + self.pool_v, self.norm_v = build_pooling(stride_kv) + + self.residual_pooling = residual_pooling + + self.rel_pos_spatial = rel_pos_spatial + self.rel_pos_zero_init = rel_pos_zero_init + if self.rel_pos_spatial: + # initialize relative positional embeddings + assert input_size[0] == input_size[1] + + size = input_size[0] + rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim)) + + def init_weights(self): + """Weight initialization.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress rel_pos_zero_init if use pretrained model. + return + + if not self.rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x, in_size): + """Forward the MultiScaleAttention.""" + B, N, _ = x.shape # (B, H*W, C) + + # qkv: (B, H*W, 3, num_heads, C) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1) + # q, k, v: (B, num_heads, H*W, C) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) + + q, q_shape = attention_pool(q, self.pool_q, in_size, norm=self.norm_q) + k, k_shape = attention_pool(k, self.pool_k, in_size, norm=self.norm_k) + v, v_shape = attention_pool(v, self.pool_v, in_size, norm=self.norm_v) + + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_spatial: + attn = add_decomposed_rel_pos(attn, q, q_shape, k_shape, + self.rel_pos_h, self.rel_pos_w) + + attn = attn.softmax(dim=-1) + x = attn @ v + + if self.residual_pooling: + x = x + q + + # (B, num_heads, H'*W', C'//num_heads) -> (B, H'*W', C') + x = x.transpose(1, 2).reshape(B, -1, self.out_dims) + x = self.proj(x) + + return x, q_shape + + +class MultiScaleBlock(BaseModule): + """Multiscale Transformer blocks. + + Args: + in_dims (int): Number of input channels. + out_dims (int): Number of output channels. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of hidden dimensions in MLP layers. + Defaults to 4.0. + qkv_bias (bool): If True, add a learnable bias to query, key and + value. Defaults to True. + drop_path (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='LN')``. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. + qkv_pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + stride_q (int): stride size for q pooling layer. Defaults to 1. + stride_kv (int): stride size for kv pooling layer. Defaults to 1. + rel_pos_spatial (bool): Whether to enable the spatial relative + position embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in + attention layers. If False, multiply it in MLP layers. + Defaults to True. + input_size (Tuple[int], optional): The input resolution, necessary + if enable the ``rel_pos_spatial``. Defaults to None. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. + """ + + def __init__( + self, + in_dims, + out_dims, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + qkv_pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + rel_pos_spatial=True, + residual_pooling=True, + dim_mul_in_attention=True, + input_size=None, + rel_pos_zero_init=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.in_dims = in_dims + self.out_dims = out_dims + self.norm1 = build_norm_layer(norm_cfg, in_dims)[1] + self.dim_mul_in_attention = dim_mul_in_attention + + attn_dims = out_dims if dim_mul_in_attention else in_dims + self.attn = MultiScaleAttention( + in_dims, + attn_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + pool_kernel=qkv_pool_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + rel_pos_spatial=rel_pos_spatial, + residual_pooling=residual_pooling, + input_size=input_size, + rel_pos_zero_init=rel_pos_zero_init) + self.drop_path = DropPath( + drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, attn_dims)[1] + + self.mlp = MLP( + in_channels=attn_dims, + hidden_channels=int(attn_dims * mlp_ratio), + out_channels=out_dims, + act_cfg=act_cfg) + + if in_dims != out_dims: + self.proj = nn.Linear(in_dims, out_dims) + else: + self.proj = None + + if stride_q > 1: + kernel_skip = stride_q + 1 + padding_skip = int(kernel_skip // 2) + self.pool_skip = nn.MaxPool2d( + kernel_skip, stride_q, padding_skip, ceil_mode=False) + + if input_size is not None: + input_size = to_2tuple(input_size) + out_size = [size // stride_q for size in input_size] + self.init_out_size = out_size + else: + self.init_out_size = None + else: + self.pool_skip = None + self.init_out_size = input_size + + def forward(self, x, in_size): + x_norm = self.norm1(x) + x_attn, out_size = self.attn(x_norm, in_size) + + if self.dim_mul_in_attention and self.proj is not None: + skip = self.proj(x_norm) + else: + skip = x + + if self.pool_skip is not None: + skip, _ = attention_pool(skip, self.pool_skip, in_size) + + x = skip + self.drop_path(x_attn) + x_norm = self.norm2(x) + x_mlp = self.mlp(x_norm) + + if not self.dim_mul_in_attention and self.proj is not None: + skip = self.proj(x_norm) + else: + skip = x + + x = skip + self.drop_path(x_mlp) + + return x, out_size + + +@BACKBONES.register_module() +class MViT(BaseBackbone): + """Multi-scale ViT v2. + + A PyTorch implement of : `MViTv2: Improved Multiscale Vision Transformers + for Classification and Detection `_ + + Inspiration from `the official implementation + `_ and `the detectron2 + implementation `_ + + Args: + arch (str | dict): MViT architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of layers. + - **num_heads** (int): The number of heads in attention + modules of the initial layer. + - **downscale_indices** (List[int]): The layer indices to downscale + the feature map. + + Defaults to 'base'. + img_size (int): The expected input image shape. Defaults to 224. + in_channels (int): The num of input channels. Defaults to 3. + out_scales (int | Sequence[int]): The output scale indices. + They should not exceed the length of ``downscale_indices``. + Defaults to -1, which means the last scale. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embedding vector resize. Defaults to "bicubic". + pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + dim_mul (int): The magnification for ``embed_dims`` in the downscale + layers. Defaults to 2. + head_mul (int): The magnification for ``num_heads`` in the downscale + layers. Defaults to 2. + adaptive_kv_stride (int): The stride size for kv pooling in the initial + layer. Defaults to 4. + rel_pos_spatial (bool): Whether to enable the spatial relative position + embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in + attention layers. If False, multiply it in MLP layers. + Defaults to True. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + mlp_ratio (float): Ratio of hidden dimensions in MLP layers. + Defaults to 4.0. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN', eps=1e-6)``. + patch_cfg (dict): Config dict for the patch embedding layer. + Defaults to ``dict(kernel_size=7, stride=4, padding=3)``. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.models import build_backbone + >>> + >>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3]) + >>> model = build_backbone(cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outputs = model(inputs) + >>> for i, output in enumerate(outputs): + >>> print(f'scale{i}: {output.shape}') + scale0: torch.Size([1, 96, 56, 56]) + scale1: torch.Size([1, 192, 28, 28]) + scale2: torch.Size([1, 384, 14, 14]) + scale3: torch.Size([1, 768, 7, 7]) + """ + arch_zoo = { + 'tiny': { + 'embed_dims': 96, + 'num_layers': 10, + 'num_heads': 1, + 'downscale_indices': [1, 3, 8] + }, + 'small': { + 'embed_dims': 96, + 'num_layers': 16, + 'num_heads': 1, + 'downscale_indices': [1, 3, 14] + }, + 'base': { + 'embed_dims': 96, + 'num_layers': 24, + 'num_heads': 1, + 'downscale_indices': [2, 5, 21] + }, + 'large': { + 'embed_dims': 144, + 'num_layers': 48, + 'num_heads': 2, + 'downscale_indices': [2, 8, 44] + }, + } + num_extra_tokens = 0 + + def __init__(self, + arch='base', + img_size=224, + in_channels=3, + out_scales=-1, + drop_path_rate=0., + use_abs_pos_embed=False, + interpolate_mode='bicubic', + pool_kernel=(3, 3), + dim_mul=2, + head_mul=2, + adaptive_kv_stride=4, + rel_pos_spatial=True, + residual_pooling=True, + dim_mul_in_attention=True, + rel_pos_zero_init=False, + mlp_ratio=4., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + patch_cfg=dict(kernel_size=7, stride=4, padding=3), + init_cfg=None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'downscale_indices' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.num_heads = self.arch_settings['num_heads'] + self.downscale_indices = self.arch_settings['downscale_indices'] + self.num_scales = len(self.downscale_indices) + 1 + self.stage_indices = { + index - 1: i + for i, index in enumerate(self.downscale_indices) + } + self.stage_indices[self.num_layers - 1] = self.num_scales - 1 + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + + if isinstance(out_scales, int): + out_scales = [out_scales] + assert isinstance(out_scales, Sequence), \ + f'"out_scales" must by a sequence or int, ' \ + f'get {type(out_scales)} instead.' + for i, index in enumerate(out_scales): + if index < 0: + out_scales[i] = self.num_scales + index + assert 0 <= out_scales[i] <= self.num_scales, \ + f'Invalid out_scales {index}' + self.out_scales = sorted(list(out_scales)) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + # Set absolute position embedding + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.blocks = ModuleList() + out_dims_list = [self.embed_dims] + num_heads = self.num_heads + stride_kv = adaptive_kv_stride + input_size = self.patch_resolution + for i in range(self.num_layers): + if i in self.downscale_indices: + num_heads *= head_mul + stride_q = 2 + stride_kv = max(stride_kv // 2, 1) + else: + stride_q = 1 + + # Set output embed_dims + if dim_mul_in_attention and i in self.downscale_indices: + # multiply embed_dims in downscale layers. + out_dims = out_dims_list[-1] * dim_mul + elif not dim_mul_in_attention and i + 1 in self.downscale_indices: + # multiply embed_dims before downscale layers. + out_dims = out_dims_list[-1] * dim_mul + else: + out_dims = out_dims_list[-1] + + attention_block = MultiScaleBlock( + in_dims=out_dims_list[-1], + out_dims=out_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_cfg=norm_cfg, + qkv_pool_kernel=pool_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + rel_pos_spatial=rel_pos_spatial, + residual_pooling=residual_pooling, + dim_mul_in_attention=dim_mul_in_attention, + input_size=input_size, + rel_pos_zero_init=rel_pos_zero_init) + self.blocks.append(attention_block) + + input_size = attention_block.init_out_size + out_dims_list.append(out_dims) + + if i in self.stage_indices: + stage_index = self.stage_indices[i] + if stage_index in self.out_scales: + norm_layer = build_norm_layer(norm_cfg, out_dims)[1] + self.add_module(f'norm{stage_index}', norm_layer) + + def init_weights(self): + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + if self.use_abs_pos_embed: + trunc_normal_(self.pos_embed, std=0.02) + + def forward(self, x): + """Forward the MViT.""" + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + + outs = [] + for i, block in enumerate(self.blocks): + x, patch_resolution = block(x, patch_resolution) + + if i in self.stage_indices: + stage_index = self.stage_indices[i] + if stage_index in self.out_scales: + B, _, C = x.shape + x = getattr(self, f'norm{stage_index}')(x) + out = x.transpose(1, 2).reshape(B, C, *patch_resolution) + outs.append(out.contiguous()) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/poolformer.py b/mmpretrain/models/backbones/poolformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ad67043dbeb0ce6969c2770853342b30df2a74 --- /dev/null +++ b/mmpretrain/models/backbones/poolformer.py @@ -0,0 +1,416 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class PatchEmbed(nn.Module): + """Patch Embedding module implemented by a layer of convolution. + + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H/stride, W/stride] + Args: + patch_size (int): Patch size of the patch embedding. Defaults to 16. + stride (int): Stride of the patch embedding. Defaults to 16. + padding (int): Padding of the patch embedding. Defaults to 0. + in_chans (int): Input channels. Defaults to 3. + embed_dim (int): Output dimension of the patch embedding. + Defaults to 768. + norm_layer (module): Normalization module. Defaults to None (not use). + """ + + def __init__(self, + patch_size=16, + stride=16, + padding=0, + in_chans=3, + embed_dim=768, + norm_layer=None): + super().__init__() + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + + +class Pooling(nn.Module): + """Pooling module. + + Args: + pool_size (int): Pooling size. Defaults to 3. + """ + + def __init__(self, pool_size=3): + super().__init__() + self.pool = nn.AvgPool2d( + pool_size, + stride=1, + padding=pool_size // 2, + count_include_pad=False) + + def forward(self, x): + return self.pool(x) - x + + +class Mlp(nn.Module): + """Mlp implemented by with 1*1 convolutions. + + Input: Tensor with shape [B, C, H, W]. + Output: Tensor with shape [B, C, H, W]. + Args: + in_features (int): Dimension of input features. + hidden_features (int): Dimension of hidden features. + out_features (int): Dimension of output features. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0.0. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class PoolFormerBlock(BaseModule): + """PoolFormer Block. + + Args: + dim (int): Embedding dim. + pool_size (int): Pooling size. Defaults to 3. + mlp_ratio (float): Mlp expansion ratio. Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='GN', num_groups=1)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-5. + """ + + def __init__(self, + dim, + pool_size=3, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = Pooling(pool_size=pool_size) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + # The following two techniques are useful to train deep PoolFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + return x + + +def basic_blocks(dim, + index, + layers, + pool_size=3, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop_rate=.0, + drop_path_rate=0., + layer_scale_init_value=1e-5): + """ + generate PoolFormer blocks for a stage + return: PoolFormer blocks + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / ( + sum(layers) - 1) + blocks.append( + PoolFormerBlock( + dim, + pool_size=pool_size, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + layer_scale_init_value=layer_scale_init_value, + )) + blocks = nn.Sequential(*blocks) + + return blocks + + +@MODELS.register_module() +class PoolFormer(BaseBackbone): + """PoolFormer. + + A PyTorch implementation of PoolFormer introduced by: + `MetaFormer is Actually What You Need for Vision `_ + + Modified from the `official repo + `. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``PoolFormer.arch_settings``. And if dict, it + should include the following two keys: + + - layers (list[int]): Number of blocks at each stage. + - embed_dims (list[int]): The number of channels at each stage. + - mlp_ratios (list[int]): Expansion ratio of MLPs. + - layer_scale_init_value (float): Init value for Layer Scale. + + Defaults to 'S12'. + + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + in_patch_size (int): The patch size of input image patch embedding. + Defaults to 7. + in_stride (int): The stride of input image patch embedding. + Defaults to 4. + in_pad (int): The padding of input image patch embedding. + Defaults to 2. + down_patch_size (int): The patch size of downsampling patch embedding. + Defaults to 3. + down_stride (int): The stride of downsampling patch embedding. + Defaults to 2. + down_pad (int): The padding of downsampling patch embedding. + Defaults to 1. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which network position. + Index 0-6 respectively corresponds to + [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4] + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims, --mlp_ratios: + # embedding dims and mlp ratios for the four stages + # --downsamples: flags to apply downsampling or not in four blocks + arch_settings = { + 's12': { + 'layers': [2, 2, 6, 2], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's24': { + 'layers': [4, 4, 12, 4], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm48': { + 'layers': [8, 8, 24, 8], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + } + + def __init__(self, + arch='s12', + pool_size=3, + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + in_patch_size=7, + in_stride=4, + in_pad=2, + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_rate=0., + drop_path_rate=0., + out_indices=-1, + frozen_stages=0, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'layers' in arch and 'embed_dims' in arch, \ + f'The arch dict must have "layers" and "embed_dims", ' \ + f'but got {list(arch.keys())}.' + + layers = arch['layers'] + embed_dims = arch['embed_dims'] + mlp_ratios = arch['mlp_ratios'] \ + if 'mlp_ratios' in arch else [4, 4, 4, 4] + layer_scale_init_value = arch['layer_scale_init_value'] \ + if 'layer_scale_init_value' in arch else 1e-5 + + self.patch_embed = PatchEmbed( + patch_size=in_patch_size, + stride=in_stride, + padding=in_pad, + in_chans=3, + embed_dim=embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(layers)): + stage = basic_blocks( + embed_dims[i], + i, + layers, + pool_size=pool_size, + mlp_ratio=mlp_ratios[i], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value) + network.append(stage) + if i >= len(layers) - 1: + break + if embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + padding=down_pad, + in_chans=embed_dims[i], + embed_dim=embed_dims[i + 1])) + + self.network = nn.ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 7 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + if self.out_indices: + for i_layer in self.out_indices: + layer = build_norm_layer(norm_cfg, + embed_dims[(i_layer + 1) // 2])[1] + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + + def forward_embeddings(self, x): + x = self.patch_embed(x) + return x + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(x) + outs.append(x_out) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + # Include both block and downsample layer. + module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(PoolFormer, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/regnet.py b/mmpretrain/models/backbones/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..85dbdef0bfeb607ecddff1d68d1cf405b61bea65 --- /dev/null +++ b/mmpretrain/models/backbones/regnet.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import ResNet +from .resnext import Bottleneck + + +@MODELS.register_module() +class RegNet(ResNet): + """RegNet backbone. + + More details can be found in `paper `_ . + + Args: + arch (dict): The parameter of RegNets. + - w0 (int): initial width + - wa (float): slope of width + - wm (float): quantization parameter to quantize the width + - depth (int): depth of the backbone + - group_w (int): width of group + - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck. + strides (Sequence[int]): Strides of the first block of each stage. + base_channels (int): Base channels after stem layer. + in_channels (int): Number of input image channels. Default: 3. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: "pytorch". + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Default: -1. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import RegNet + >>> import torch + >>> self = RegNet( + arch=dict( + w0=88, + wa=26.31, + wm=2.25, + group_w=48, + depth=25, + bot_mul=1.0)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 96, 8, 8) + (1, 192, 4, 4) + (1, 432, 2, 2) + (1, 1008, 1, 1) + """ + arch_settings = { + 'regnetx_400mf': + dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0), + 'regnetx_800mf': + dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0), + 'regnetx_1.6gf': + dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0), + 'regnetx_3.2gf': + dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0), + 'regnetx_4.0gf': + dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0), + 'regnetx_6.4gf': + dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0), + 'regnetx_8.0gf': + dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0), + 'regnetx_12gf': + dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0), + } + + def __init__(self, + arch, + in_channels=3, + stem_channels=32, + base_channels=32, + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + + # Generate RegNet parameters first + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the' \ + ' arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + widths, num_stages = self.generate_regnet( + arch['w0'], + arch['wa'], + arch['wm'], + arch['depth'], + ) + # Convert to per stage format + stage_widths, stage_blocks = self.get_stages_from_blocks(widths) + # Generate group widths and bot muls + group_widths = [arch['group_w'] for _ in range(num_stages)] + self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)] + # Adjust the compatibility of stage_widths and group_widths + stage_widths, group_widths = self.adjust_width_group( + stage_widths, self.bottleneck_ratio, group_widths) + + # Group params by stage + self.stage_widths = stage_widths + self.group_widths = group_widths + self.depth = sum(stage_blocks) + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + if self.deep_stem: + raise NotImplementedError( + 'deep_stem has not been implemented for RegNet') + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.stage_blocks = stage_blocks[:num_stages] + + self._make_stem_layer(in_channels, stem_channels) + + _in_channels = stem_channels + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = self.strides[i] + dilation = self.dilations[i] + group_width = self.group_widths[i] + width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i])) + stage_groups = width // group_width + + res_layer = self.make_res_layer( + block=Bottleneck, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=self.stage_widths[i], + expansion=1, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + base_channels=self.stage_widths[i], + groups=stage_groups, + width_per_group=group_width) + _in_channels = self.stage_widths[i] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = stage_widths[-1] + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + def generate_regnet(self, + initial_width, + width_slope, + width_parameter, + depth, + divisor=8): + """Generates per block width from RegNet parameters. + + Args: + initial_width ([int]): Initial width of the backbone + width_slope ([float]): Slope of the quantized linear function + width_parameter ([int]): Parameter used to quantize the width. + depth ([int]): Depth of the backbone. + divisor (int): The divisor of channels. Defaults to 8. + + Returns: + tuple: tuple containing: + - list: Widths of each stage. + - int: The number of stages. + """ + assert width_slope >= 0 + assert initial_width > 0 + assert width_parameter > 1 + assert initial_width % divisor == 0 + widths_cont = np.arange(depth) * width_slope + initial_width + ks = np.round( + np.log(widths_cont / initial_width) / np.log(width_parameter)) + widths = initial_width * np.power(width_parameter, ks) + widths = np.round(np.divide(widths, divisor)) * divisor + num_stages = len(np.unique(widths)) + widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist() + return widths, num_stages + + @staticmethod + def quantize_float(number, divisor): + """Converts a float to closest non-zero int divisible by divior. + + Args: + number (int): Original number to be quantized. + divisor (int): Divisor used to quantize the number. + + Returns: + int: quantized number that is divisible by devisor. + """ + return int(round(number / divisor) * divisor) + + def adjust_width_group(self, widths, bottleneck_ratio, groups): + """Adjusts the compatibility of widths and groups. + + Args: + widths (list[int]): Width of each stage. + bottleneck_ratio (float): Bottleneck ratio. + groups (int): number of groups in each stage + + Returns: + tuple(list): The adjusted widths and groups of each stage. + """ + bottleneck_width = [ + int(w * b) for w, b in zip(widths, bottleneck_ratio) + ] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)] + bottleneck_width = [ + self.quantize_float(w_bot, g) + for w_bot, g in zip(bottleneck_width, groups) + ] + widths = [ + int(w_bot / b) + for w_bot, b in zip(bottleneck_width, bottleneck_ratio) + ] + return widths, groups + + def get_stages_from_blocks(self, widths): + """Gets widths/stage_blocks of network at each stage. + + Args: + widths (list[int]): Width in each stage. + + Returns: + tuple(list): width and depth of each stage + """ + width_diff = [ + width != width_prev + for width, width_prev in zip(widths + [0], [0] + widths) + ] + stage_widths = [ + width for width, diff in zip(widths, width_diff[:-1]) if diff + ] + stage_blocks = np.diff([ + depth for depth, diff in zip(range(len(width_diff)), width_diff) + if diff + ]).tolist() + return stage_widths, stage_blocks + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/replknet.py b/mmpretrain/models/backbones/replknet.py new file mode 100644 index 0000000000000000000000000000000000000000..4dce4154fbe1d95806eec118b69ff70f0d74c1c6 --- /dev/null +++ b/mmpretrain/models/backbones/replknet.py @@ -0,0 +1,668 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +def conv_bn(in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1, + norm_cfg=dict(type='BN')): + """Construct a sequential conv and bn. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the convolution. + stride (int): stride of the convolution. + padding (int): stride of the convolution. + groups (int): groups of the convolution. + dilation (int): dilation of the convolution. Default to 1. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + + Returns: + nn.Sequential(): A conv layer and a batch norm layer. + """ + if padding is None: + padding = kernel_size // 2 + result = nn.Sequential() + result.add_module( + 'conv', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False)) + result.add_module('bn', build_norm_layer(norm_cfg, out_channels)[1]) + return result + + +def conv_bn_relu(in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1): + """Construct a sequential conv, bn and relu. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the convolution. + stride (int): stride of the convolution. + padding (int): stride of the convolution. + groups (int): groups of the convolution. + dilation (int): dilation of the convolution. Default to 1. + + Returns: + nn.Sequential(): A conv layer, batch norm layer and a relu function. + """ + + if padding is None: + padding = kernel_size // 2 + result = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + dilation=dilation) + result.add_module('nonlinear', nn.ReLU()) + return result + + +def fuse_bn(conv, bn): + """Fuse the parameters in a branch with a conv and bn. + + Args: + conv (nn.Conv2d): The convolution module to fuse. + bn (nn.BatchNorm2d): The batch normalization to fuse. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The parameters obtained after + fusing the parameters of conv and bn in one branch. + The first element is the weight and the second is the bias. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class ReparamLargeKernelConv(BaseModule): + """Super large kernel implemented by with large convolutions. + + Input: Tensor with shape [B, C, H, W]. + Output: Tensor with shape [B, C, H, W]. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the large convolution. + stride (int): stride of the large convolution. + groups (int): groups of the large convolution. + small_kernel (int): kernel_size of the small convolution. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups, + small_kernel, + small_kernel_merged=False, + init_cfg=None): + super(ReparamLargeKernelConv, self).__init__(init_cfg) + self.kernel_size = kernel_size + self.small_kernel = small_kernel + self.small_kernel_merged = small_kernel_merged + # We assume the conv does not change the feature map size, + # so padding = k//2. + # Otherwise, you may configure padding as you wish, + # and change the padding of small_conv accordingly. + padding = kernel_size // 2 + if small_kernel_merged: + self.lkb_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups, + bias=True) + else: + self.lkb_origin = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups) + if small_kernel is not None: + assert small_kernel <= kernel_size + self.small_conv = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=small_kernel, + stride=stride, + padding=small_kernel // 2, + groups=groups, + dilation=1) + + def forward(self, inputs): + if hasattr(self, 'lkb_reparam'): + out = self.lkb_reparam(inputs) + else: + out = self.lkb_origin(inputs) + if hasattr(self, 'small_conv'): + out += self.small_conv(inputs) + return out + + def get_equivalent_kernel_bias(self): + eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + if hasattr(self, 'small_conv'): + small_k, small_b = fuse_bn(self.small_conv.conv, + self.small_conv.bn) + eq_b += small_b + # add to the central part + eq_k += nn.functional.pad( + small_k, [(self.kernel_size - self.small_kernel) // 2] * 4) + return eq_k, eq_b + + def merge_kernel(self): + """Switch the model structure from training mode to deployment mode.""" + if self.small_kernel_merged: + return + eq_k, eq_b = self.get_equivalent_kernel_bias() + self.lkb_reparam = nn.Conv2d( + in_channels=self.lkb_origin.conv.in_channels, + out_channels=self.lkb_origin.conv.out_channels, + kernel_size=self.lkb_origin.conv.kernel_size, + stride=self.lkb_origin.conv.stride, + padding=self.lkb_origin.conv.padding, + dilation=self.lkb_origin.conv.dilation, + groups=self.lkb_origin.conv.groups, + bias=True) + + self.lkb_reparam.weight.data = eq_k + self.lkb_reparam.bias.data = eq_b + self.__delattr__('lkb_origin') + if hasattr(self, 'small_conv'): + self.__delattr__('small_conv') + + self.small_kernel_merged = True + + +class ConvFFN(BaseModule): + """Mlp implemented by with 1*1 convolutions. + + Input: Tensor with shape [B, C, H, W]. + Output: Tensor with shape [B, C, H, W]. + + Args: + in_channels (int): Dimension of input features. + internal_channels (int): Dimension of hidden features. + out_channels (int): Dimension of output features. + drop_path (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels, + internal_channels, + out_channels, + drop_path, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(ConvFFN, self).__init__(init_cfg) + self.drop_path = DropPath( + drop_prob=drop_path) if drop_path > 0. else nn.Identity() + self.preffn_bn = build_norm_layer(norm_cfg, in_channels)[1] + self.pw1 = conv_bn( + in_channels=in_channels, + out_channels=internal_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.pw2 = conv_bn( + in_channels=internal_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.nonlinear = build_activation_layer(act_cfg) + + def forward(self, x): + out = self.preffn_bn(x) + out = self.pw1(out) + out = self.nonlinear(out) + out = self.pw2(out) + return x + self.drop_path(out) + + +class RepLKBlock(BaseModule): + """RepLKBlock for RepLKNet backbone. + + Args: + in_channels (int): The input channels of the block. + dw_channels (int): The intermediate channels of the block, + i.e., input channels of the large kernel convolution. + block_lk_size (int): size of the super large kernel. Defaults: 31. + small_kernel (int): size of the parallel small kernel. Defaults: 5. + drop_path (float): Stochastic depth rate. Defaults: 0. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + act_cfg (dict): Config dict for activation layer. + Default to ``dict(type='ReLU')``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default to None + """ + + def __init__(self, + in_channels, + dw_channels, + block_lk_size, + small_kernel, + drop_path, + small_kernel_merged=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(RepLKBlock, self).__init__(init_cfg) + self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1) + self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1) + self.large_kernel = ReparamLargeKernelConv( + in_channels=dw_channels, + out_channels=dw_channels, + kernel_size=block_lk_size, + stride=1, + groups=dw_channels, + small_kernel=small_kernel, + small_kernel_merged=small_kernel_merged) + self.lk_nonlinear = build_activation_layer(act_cfg) + self.prelkb_bn = build_norm_layer(norm_cfg, in_channels)[1] + self.drop_path = DropPath( + drop_prob=drop_path) if drop_path > 0. else nn.Identity() + # print('drop path:', self.drop_path) + + def forward(self, x): + out = self.prelkb_bn(x) + out = self.pw1(out) + out = self.large_kernel(out) + out = self.lk_nonlinear(out) + out = self.pw2(out) + return x + self.drop_path(out) + + +class RepLKNetStage(BaseModule): + """ + generate RepLKNet blocks for a stage + return: RepLKNet blocks + + Args: + channels (int): The input channels of the stage. + num_blocks (int): The number of blocks of the stage. + stage_lk_size (int): size of the super large kernel. Defaults: 31. + drop_path (float): Stochastic depth rate. Defaults: 0. + small_kernel (int): size of the parallel small kernel. Defaults: 5. + dw_ratio (float): The intermediate channels + expansion ratio of the block. Defaults: 1. + ffn_ratio (float): Mlp expansion ratio. Defaults to 4. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default to False. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + norm_intermediate_features (bool): Construct and config norm layer + or not. + Using True will normalize the intermediate features for + downstream dense prediction tasks. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default to None + """ + + def __init__( + self, + channels, + num_blocks, + stage_lk_size, + drop_path, + small_kernel, + dw_ratio=1, + ffn_ratio=4, + with_cp=False, # train with torch.utils.checkpoint to save memory + small_kernel_merged=False, + norm_intermediate_features=False, + norm_cfg=dict(type='BN'), + init_cfg=None): + super(RepLKNetStage, self).__init__(init_cfg) + self.with_cp = with_cp + blks = [] + for i in range(num_blocks): + block_drop_path = drop_path[i] if isinstance(drop_path, + list) else drop_path + # Assume all RepLK Blocks within a stage share the same lk_size. + # You may tune it on your own model. + replk_block = RepLKBlock( + in_channels=channels, + dw_channels=int(channels * dw_ratio), + block_lk_size=stage_lk_size, + small_kernel=small_kernel, + drop_path=block_drop_path, + small_kernel_merged=small_kernel_merged) + convffn_block = ConvFFN( + in_channels=channels, + internal_channels=int(channels * ffn_ratio), + out_channels=channels, + drop_path=block_drop_path) + blks.append(replk_block) + blks.append(convffn_block) + self.blocks = nn.ModuleList(blks) + if norm_intermediate_features: + self.norm = build_norm_layer(norm_cfg, channels)[1] + else: + self.norm = nn.Identity() + + def forward(self, x): + for blk in self.blocks: + if self.with_cp: + x = checkpoint.checkpoint(blk, x) # Save training memory + else: + x = blk(x) + return x + + +@MODELS.register_module() +class RepLKNet(BaseBackbone): + """RepLKNet backbone. + + A PyTorch impl of : + `Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs + `_ + + Args: + arch (str | dict): The parameter of RepLKNet. + If it's a dict, it should contain the following keys: + + - large_kernel_sizes (Sequence[int]): + Large kernel size in each stage. + - layers (Sequence[int]): Number of blocks in each stage. + - channels (Sequence[int]): Number of channels in each stage. + - small_kernel (int): size of the parallel small kernel. + - dw_ratio (float): The intermediate channels + expansion ratio of the block. + in_channels (int): Number of input image channels. Default to 3. + ffn_ratio (float): Mlp expansion ratio. Defaults to 4. + out_indices (Sequence[int]): Output from which stages. + Default to (3, ). + strides (Sequence[int]): Strides of the first block of each stage. + Default to (2, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. + Default to (1, 1, 1, 1). + frozen_stages (int): Stages to be frozen + (all param fixed). -1 means not freezing any parameters. + Default to -1. + conv_cfg (dict | None): The config dict for conv layers. + Default to None. + norm_cfg (dict): The config dict for norm layers. + Default to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Default to ``dict(type='ReLU')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default to False. + deploy (bool): Whether to switch the model structure to deployment + mode. Default to False. + norm_intermediate_features (bool): Construct and + config norm layer or not. + Using True will normalize the intermediate features + for downstream dense prediction tasks. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + arch_settings = { + '31B': + dict( + large_kernel_sizes=[31, 29, 27, 13], + layers=[2, 2, 18, 2], + channels=[128, 256, 512, 1024], + small_kernel=5, + dw_ratio=1), + '31L': + dict( + large_kernel_sizes=[31, 29, 27, 13], + layers=[2, 2, 18, 2], + channels=[192, 384, 768, 1536], + small_kernel=5, + dw_ratio=1), + 'XL': + dict( + large_kernel_sizes=[27, 27, 27, 13], + layers=[2, 2, 18, 2], + channels=[256, 512, 1024, 2048], + small_kernel=None, + dw_ratio=1.5), + } + + def __init__(self, + arch, + in_channels=3, + ffn_ratio=4, + out_indices=(3, ), + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + drop_path_rate=0.3, + small_kernel_merged=False, + norm_intermediate_features=False, + norm_eval=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(RepLKNet, self).__init__(init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + assert len(arch['layers']) == len( + arch['channels']) == len(strides) == len(dilations) + assert max(out_indices) < len(arch['layers']) + + self.arch = arch + self.in_channels = in_channels + self.out_indices = out_indices + self.strides = strides + self.dilations = dilations + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.with_cp = with_cp + self.drop_path_rate = drop_path_rate + self.small_kernel_merged = small_kernel_merged + self.norm_eval = norm_eval + self.norm_intermediate_features = norm_intermediate_features + + self.out_indices = out_indices + + base_width = self.arch['channels'][0] + self.norm_intermediate_features = norm_intermediate_features + self.num_stages = len(self.arch['layers']) + self.stem = nn.ModuleList([ + conv_bn_relu( + in_channels=in_channels, + out_channels=base_width, + kernel_size=3, + stride=2, + padding=1, + groups=1), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=3, + stride=1, + padding=1, + groups=base_width), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=1, + stride=1, + padding=0, + groups=1), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=3, + stride=2, + padding=1, + groups=base_width) + ]) + # stochastic depth. We set block-wise drop-path rate. + # The higher level blocks are more likely to be dropped. + # This implementation follows Swin. + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, + sum(self.arch['layers'])) + ] + self.stages = nn.ModuleList() + self.transitions = nn.ModuleList() + for stage_idx in range(self.num_stages): + layer = RepLKNetStage( + channels=self.arch['channels'][stage_idx], + num_blocks=self.arch['layers'][stage_idx], + stage_lk_size=self.arch['large_kernel_sizes'][stage_idx], + drop_path=dpr[sum(self.arch['layers'][:stage_idx] + ):sum(self.arch['layers'][:stage_idx + 1])], + small_kernel=self.arch['small_kernel'], + dw_ratio=self.arch['dw_ratio'], + ffn_ratio=ffn_ratio, + with_cp=with_cp, + small_kernel_merged=small_kernel_merged, + norm_intermediate_features=(stage_idx in out_indices)) + self.stages.append(layer) + if stage_idx < len(self.arch['layers']) - 1: + transition = nn.Sequential( + conv_bn_relu( + self.arch['channels'][stage_idx], + self.arch['channels'][stage_idx + 1], + 1, + 1, + 0, + groups=1), + conv_bn_relu( + self.arch['channels'][stage_idx + 1], + self.arch['channels'][stage_idx + 1], + 3, + stride=2, + padding=1, + groups=self.arch['channels'][stage_idx + 1])) + self.transitions.append(transition) + + def forward_features(self, x): + x = self.stem[0](x) + for stem_layer in self.stem[1:]: + if self.with_cp: + x = checkpoint.checkpoint(stem_layer, x) # save memory + else: + x = stem_layer(x) + + # Need the intermediate feature maps + outs = [] + for stage_idx in range(self.num_stages): + x = self.stages[stage_idx](x) + if stage_idx in self.out_indices: + outs.append(self.stages[stage_idx].norm(x)) + # For RepLKNet-XL normalize the features + # before feeding them into the heads + if stage_idx < self.num_stages - 1: + x = self.transitions[stage_idx](x) + return outs + + def forward(self, x): + x = self.forward_features(x) + return tuple(x) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RepLKNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def switch_to_deploy(self): + for m in self.modules(): + if hasattr(m, 'merge_kernel'): + m.merge_kernel() + self.small_kernel_merged = True diff --git a/mmpretrain/models/backbones/repmlp.py b/mmpretrain/models/backbones/repmlp.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c06c4875710b33c57f2794c437034d93169b30 --- /dev/null +++ b/mmpretrain/models/backbones/repmlp.py @@ -0,0 +1,578 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from official impl at https://github.com/DingXiaoH/RepMLP. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.models.utils import SELayer, to_2tuple +from mmpretrain.registry import MODELS + + +def fuse_bn(conv_or_fc, bn): + """fuse conv and bn.""" + std = (bn.running_var + bn.eps).sqrt() + tmp_weight = bn.weight / std + tmp_weight = tmp_weight.reshape(-1, 1, 1, 1) + + if len(tmp_weight) == conv_or_fc.weight.size(0): + return (conv_or_fc.weight * tmp_weight, + bn.bias - bn.running_mean * bn.weight / std) + else: + # in RepMLPBlock, dim0 of fc3 weights and fc3_bn weights + # are different. + repeat_times = conv_or_fc.weight.size(0) // len(tmp_weight) + repeated = tmp_weight.repeat_interleave(repeat_times, 0) + fused_weight = conv_or_fc.weight * repeated + bias = bn.bias - bn.running_mean * bn.weight / std + fused_bias = (bias).repeat_interleave(repeat_times, 0) + return (fused_weight, fused_bias) + + +class PatchEmbed(_PatchEmbed): + """Image to Patch Embedding. + + Compared with default Patch Embedding(in ViT), Patch Embedding of RepMLP + have ReLu and do not convert output tensor into shape (N, L, C). + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The type of convolution + to generate patch embedding. Default: "Conv2d". + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: 16. + padding (int | tuple | string): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only works when `dynamic_size` + is False. Default: None. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, *args, **kwargs): + super(PatchEmbed, self).__init__(*args, **kwargs) + self.relu = nn.ReLU() + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + Returns: + tuple: Contains merged results and its spatial shape. + - x (Tensor): The output tensor. + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adaptive_padding: + x = self.adaptive_padding(x) + + x = self.projection(x) + if self.norm is not None: + x = self.norm(x) + x = self.relu(x) + out_size = (x.shape[2], x.shape[3]) + return x, out_size + + +class GlobalPerceptron(SELayer): + """GlobalPerceptron implemented by using ``mmpretrain.modes.SELayer``. + + Args: + input_channels (int): The number of input (and output) channels + in the GlobalPerceptron. + ratio (int): Squeeze ratio in GlobalPerceptron, the intermediate + channel will be ``make_divisible(channels // ratio, divisor)``. + """ + + def __init__(self, input_channels: int, ratio: int, **kwargs) -> None: + super(GlobalPerceptron, self).__init__( + channels=input_channels, + ratio=ratio, + return_weight=True, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + **kwargs) + + +class RepMLPBlock(BaseModule): + """Basic RepMLPNet, consists of PartitionPerceptron and GlobalPerceptron. + + Args: + channels (int): The number of input and the output channels of the + block. + path_h (int): The height of patches. + path_w (int): The weidth of patches. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + channels, + path_h, + path_w, + reparam_conv_kernels=None, + globalperceptron_ratio=4, + num_sharesets=1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + deploy=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.deploy = deploy + self.channels = channels + self.num_sharesets = num_sharesets + self.path_h, self.path_w = path_h, path_w + # the input channel of fc3 + self._path_vec_channles = path_h * path_w * num_sharesets + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.gp = GlobalPerceptron( + input_channels=channels, ratio=globalperceptron_ratio) + + # using a conv layer to implement a fc layer + self.fc3 = build_conv_layer( + conv_cfg, + in_channels=self._path_vec_channles, + out_channels=self._path_vec_channles, + kernel_size=1, + stride=1, + padding=0, + bias=deploy, + groups=num_sharesets) + if deploy: + self.fc3_bn = nn.Identity() + else: + norm_layer = build_norm_layer(norm_cfg, num_sharesets)[1] + self.add_module('fc3_bn', norm_layer) + + self.reparam_conv_kernels = reparam_conv_kernels + if not deploy and reparam_conv_kernels is not None: + for k in reparam_conv_kernels: + conv_branch = ConvModule( + in_channels=num_sharesets, + out_channels=num_sharesets, + kernel_size=k, + stride=1, + padding=k // 2, + norm_cfg=dict(type='BN', requires_grad=True), + groups=num_sharesets, + act_cfg=None) + self.__setattr__('repconv{}'.format(k), conv_branch) + + def partition(self, x, h_parts, w_parts): + # convert (N, C, H, W) to (N, h_parts, w_parts, C, path_h, path_w) + x = x.reshape(-1, self.channels, h_parts, self.path_h, w_parts, + self.path_w) + x = x.permute(0, 2, 4, 1, 3, 5) + return x + + def partition_affine(self, x, h_parts, w_parts): + """perform Partition Perceptron.""" + fc_inputs = x.reshape(-1, self._path_vec_channles, 1, 1) + out = self.fc3(fc_inputs) + out = out.reshape(-1, self.num_sharesets, self.path_h, self.path_w) + out = self.fc3_bn(out) + out = out.reshape(-1, h_parts, w_parts, self.num_sharesets, + self.path_h, self.path_w) + return out + + def forward(self, inputs): + # Global Perceptron + global_vec = self.gp(inputs) + + origin_shape = inputs.size() + h_parts = origin_shape[2] // self.path_h + w_parts = origin_shape[3] // self.path_w + + partitions = self.partition(inputs, h_parts, w_parts) + + # Channel Perceptron + fc3_out = self.partition_affine(partitions, h_parts, w_parts) + + # perform Local Perceptron + if self.reparam_conv_kernels is not None and not self.deploy: + conv_inputs = partitions.reshape(-1, self.num_sharesets, + self.path_h, self.path_w) + conv_out = 0 + for k in self.reparam_conv_kernels: + conv_branch = self.__getattr__('repconv{}'.format(k)) + conv_out += conv_branch(conv_inputs) + conv_out = conv_out.reshape(-1, h_parts, w_parts, + self.num_sharesets, self.path_h, + self.path_w) + fc3_out += conv_out + + # N, h_parts, w_parts, num_sharesets, out_h, out_w + fc3_out = fc3_out.permute(0, 3, 1, 4, 2, 5) + out = fc3_out.reshape(*origin_shape) + out = out * global_vec + return out + + def get_equivalent_fc3(self): + """get the equivalent fc3 weight and bias.""" + fc_weight, fc_bias = fuse_bn(self.fc3, self.fc3_bn) + if self.reparam_conv_kernels is not None: + largest_k = max(self.reparam_conv_kernels) + largest_branch = self.__getattr__('repconv{}'.format(largest_k)) + total_kernel, total_bias = fuse_bn(largest_branch.conv, + largest_branch.bn) + for k in self.reparam_conv_kernels: + if k != largest_k: + k_branch = self.__getattr__('repconv{}'.format(k)) + kernel, bias = fuse_bn(k_branch.conv, k_branch.bn) + total_kernel += F.pad(kernel, [(largest_k - k) // 2] * 4) + total_bias += bias + rep_weight, rep_bias = self._convert_conv_to_fc( + total_kernel, total_bias) + final_fc3_weight = rep_weight.reshape_as(fc_weight) + fc_weight + final_fc3_bias = rep_bias + fc_bias + else: + final_fc3_weight = fc_weight + final_fc3_bias = fc_bias + return final_fc3_weight, final_fc3_bias + + def local_inject(self): + """inject the Local Perceptron into Partition Perceptron.""" + self.deploy = True + # Locality Injection + fc3_weight, fc3_bias = self.get_equivalent_fc3() + # Remove Local Perceptron + if self.reparam_conv_kernels is not None: + for k in self.reparam_conv_kernels: + self.__delattr__('repconv{}'.format(k)) + self.__delattr__('fc3') + self.__delattr__('fc3_bn') + self.fc3 = build_conv_layer( + self.conv_cfg, + self._path_vec_channles, + self._path_vec_channles, + 1, + 1, + 0, + bias=True, + groups=self.num_sharesets) + self.fc3_bn = nn.Identity() + self.fc3.weight.data = fc3_weight + self.fc3.bias.data = fc3_bias + + def _convert_conv_to_fc(self, conv_kernel, conv_bias): + """convert conv_k1 to fc, which is still a conv_k2, and the k2 > k1.""" + in_channels = torch.eye(self.path_h * self.path_w).repeat( + 1, self.num_sharesets).reshape(self.path_h * self.path_w, + self.num_sharesets, self.path_h, + self.path_w).to(conv_kernel.device) + fc_k = F.conv2d( + in_channels, + conv_kernel, + padding=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2), + groups=self.num_sharesets) + fc_k = fc_k.reshape(self.path_w * self.path_w, self.num_sharesets * + self.path_h * self.path_w).t() + fc_bias = conv_bias.repeat_interleave(self.path_h * self.path_w) + return fc_k, fc_bias + + +class RepMLPNetUnit(BaseModule): + """A basic unit in RepMLPNet : [REPMLPBlock + BN + ConvFFN + BN]. + + Args: + channels (int): The number of input and the output channels of the + unit. + path_h (int): The height of patches. + path_w (int): The weidth of patches. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + channels, + path_h, + path_w, + reparam_conv_kernels, + globalperceptron_ratio, + norm_cfg=dict(type='BN', requires_grad=True), + ffn_expand=4, + num_sharesets=1, + deploy=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.repmlp_block = RepMLPBlock( + channels=channels, + path_h=path_h, + path_w=path_w, + reparam_conv_kernels=reparam_conv_kernels, + globalperceptron_ratio=globalperceptron_ratio, + num_sharesets=num_sharesets, + deploy=deploy) + self.ffn_block = ConvFFN(channels, channels * ffn_expand) + norm1 = build_norm_layer(norm_cfg, channels)[1] + self.add_module('norm1', norm1) + norm2 = build_norm_layer(norm_cfg, channels)[1] + self.add_module('norm2', norm2) + + def forward(self, x): + y = x + self.repmlp_block(self.norm1(x)) + out = y + self.ffn_block(self.norm2(y)) + return out + + +class ConvFFN(nn.Module): + """ConvFFN implemented by using point-wise convs.""" + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='GELU')): + super().__init__() + out_features = out_channels or in_channels + hidden_features = hidden_channels or in_channels + self.ffn_fc1 = ConvModule( + in_channels=in_channels, + out_channels=hidden_features, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + act_cfg=None) + self.ffn_fc2 = ConvModule( + in_channels=hidden_features, + out_channels=out_features, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + act_cfg=None) + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + x = self.ffn_fc1(x) + x = self.act(x) + x = self.ffn_fc2(x) + return x + + +@MODELS.register_module() +class RepMLPNet(BaseModule): + """RepMLPNet backbone. + + A PyTorch impl of : `RepMLP: Re-parameterizing Convolutions into + Fully-connected Layers for Image Recognition + `_ + + Args: + arch (str | dict): RepMLP architecture. If use string, choose + from 'base' and 'b'. If use dict, it should have below keys: + + - channels (List[int]): Number of blocks in each stage. + - depths (List[int]): The number of blocks in each branch. + - sharesets_nums (List[int]): RepVGG Block that declares + the need to apply group convolution. + + img_size (int | tuple): The size of input image. Defaults: 224. + in_channels (int): Number of input image channels. Default: 3. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + Default: dict(type='BN', requires_grad=True). + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to deployment + mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + arch_zoo = { + **dict.fromkeys(['b', 'base'], + {'channels': [96, 192, 384, 768], + 'depths': [2, 2, 12, 2], + 'sharesets_nums': [1, 4, 32, 128]}), + } # yapf: disable + + num_extra_tokens = 0 # there is no cls-token in RepMLP + + def __init__(self, + arch, + img_size=224, + in_channels=3, + patch_size=4, + out_indices=(3, ), + reparam_conv_kernels=(3, ), + globalperceptron_ratio=4, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + patch_cfg=dict(), + final_norm=True, + deploy=False, + init_cfg=None): + super(RepMLPNet, self).__init__(init_cfg=init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'channels', 'depths', 'sharesets_nums'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}.' + self.arch_settings = arch + + self.img_size = to_2tuple(img_size) + self.patch_size = to_2tuple(patch_size) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.num_stage = len(self.arch_settings['channels']) + for value in self.arch_settings.values(): + assert isinstance(value, list) and len(value) == self.num_stage, ( + 'Length of setting item in arch dict must be type of list and' + ' have the same length.') + + self.channels = self.arch_settings['channels'] + self.depths = self.arch_settings['depths'] + self.sharesets_nums = self.arch_settings['sharesets_nums'] + + _patch_cfg = dict( + in_channels=in_channels, + input_size=self.img_size, + embed_dims=self.channels[0], + conv_type='Conv2d', + kernel_size=self.patch_size, + stride=self.patch_size, + norm_cfg=self.norm_cfg, + bias=False) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + self.patch_hs = [ + self.patch_resolution[0] // 2**i for i in range(self.num_stage) + ] + self.patch_ws = [ + self.patch_resolution[1] // 2**i for i in range(self.num_stage) + ] + + self.stages = ModuleList() + self.downsample_layers = ModuleList() + for stage_idx in range(self.num_stage): + # make stage layers + _stage_cfg = dict( + channels=self.channels[stage_idx], + path_h=self.patch_hs[stage_idx], + path_w=self.patch_ws[stage_idx], + reparam_conv_kernels=reparam_conv_kernels, + globalperceptron_ratio=globalperceptron_ratio, + norm_cfg=self.norm_cfg, + ffn_expand=4, + num_sharesets=self.sharesets_nums[stage_idx], + deploy=deploy) + stage_blocks = [ + RepMLPNetUnit(**_stage_cfg) + for _ in range(self.depths[stage_idx]) + ] + self.stages.append(Sequential(*stage_blocks)) + + # make downsample layers + if stage_idx < self.num_stage - 1: + self.downsample_layers.append( + ConvModule( + in_channels=self.channels[stage_idx], + out_channels=self.channels[stage_idx + 1], + kernel_size=2, + stride=2, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + + self.out_indice = out_indices + + if final_norm: + norm_layer = build_norm_layer(norm_cfg, self.channels[-1])[1] + else: + norm_layer = nn.Identity() + self.add_module('final_norm', norm_layer) + + def forward(self, x): + assert x.shape[2:] == self.img_size, \ + "The Rep-MLP doesn't support dynamic input shape. " \ + f'Please input images with shape {self.img_size}' + + outs = [] + + x, _ = self.patch_embed(x) + for i, stage in enumerate(self.stages): + x = stage(x) + + # downsample after each stage except last stage + if i < len(self.stages) - 1: + downsample = self.downsample_layers[i] + x = downsample(x) + + if i in self.out_indice: + if self.final_norm and i == len(self.stages) - 1: + out = self.final_norm(x) + else: + out = x + outs.append(out) + + return tuple(outs) + + def switch_to_deploy(self): + for m in self.modules(): + if hasattr(m, 'local_inject'): + m.local_inject() diff --git a/mmpretrain/models/backbones/repvgg.py b/mmpretrain/models/backbones/repvgg.py new file mode 100644 index 0000000000000000000000000000000000000000..67c9d147546eb2839a44749040a1a787ee5ce0ea --- /dev/null +++ b/mmpretrain/models/backbones/repvgg.py @@ -0,0 +1,622 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmengine.model import BaseModule, Sequential +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm +from torch import nn + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .base_backbone import BaseBackbone + + +class RepVGGBlock(BaseModule): + """RepVGG block for RepVGG backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1. + padding (int): Padding of the 3x3 convolution layer. + dilation (int): Dilation of the 3x3 convolution layer. + groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1. + padding_mode (str): Padding mode of the 3x3 convolution layer. + Default: 'zeros'. + se_cfg (None or dict): The configuration of the se module. + Default: None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + padding=1, + dilation=1, + groups=1, + padding_mode='zeros', + se_cfg=None, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + deploy=False, + init_cfg=None): + super(RepVGGBlock, self).__init__(init_cfg) + + assert se_cfg is None or isinstance(se_cfg, dict) + + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.se_cfg = se_cfg + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.deploy = deploy + + if deploy: + self.branch_reparam = build_conv_layer( + conv_cfg, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + padding_mode=padding_mode) + else: + # judge if input shape and output shape are the same. + # If true, add a normalized identity shortcut. + if out_channels == in_channels and stride == 1 and \ + padding == dilation: + self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1] + else: + self.branch_norm = None + + self.branch_3x3 = self.create_conv_bn( + kernel_size=3, + dilation=dilation, + padding=padding, + ) + self.branch_1x1 = self.create_conv_bn(kernel_size=1) + + if se_cfg is not None: + self.se_layer = SELayer(channels=out_channels, **se_cfg) + else: + self.se_layer = None + + self.act = build_activation_layer(act_cfg) + + def create_conv_bn(self, kernel_size, dilation=1, padding=0): + conv_bn = Sequential() + conv_bn.add_module( + 'conv', + build_conv_layer( + self.conv_cfg, + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + stride=self.stride, + dilation=dilation, + padding=padding, + groups=self.groups, + bias=False)) + conv_bn.add_module( + 'norm', + build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1]) + + return conv_bn + + def forward(self, x): + + def _inner_forward(inputs): + if self.deploy: + return self.branch_reparam(inputs) + + if self.branch_norm is None: + branch_norm_out = 0 + else: + branch_norm_out = self.branch_norm(inputs) + + inner_out = self.branch_3x3(inputs) + self.branch_1x1( + inputs) + branch_norm_out + + if self.se_cfg is not None: + inner_out = self.se_layer(inner_out) + + return inner_out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.act(out) + + return out + + def switch_to_deploy(self): + """Switch the model structure from training mode to deployment mode.""" + if self.deploy: + return + assert self.norm_cfg['type'] == 'BN', \ + "Switch is not allowed when norm_cfg['type'] != 'BN'." + + reparam_weight, reparam_bias = self.reparameterize() + self.branch_reparam = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.out_channels, + kernel_size=3, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + bias=True) + self.branch_reparam.weight.data = reparam_weight + self.branch_reparam.bias.data = reparam_bias + + for param in self.parameters(): + param.detach_() + delattr(self, 'branch_3x3') + delattr(self, 'branch_1x1') + delattr(self, 'branch_norm') + + self.deploy = True + + def reparameterize(self): + """Fuse all the parameters of all branches. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all + branches. the first element is the weights and the second is + the bias. + """ + weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3) + weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1) + # pad a conv1x1 weight to a conv3x3 weight + weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0) + + weight_norm, bias_norm = 0, 0 + if self.branch_norm: + tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm) + weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn) + + return (weight_3x3 + weight_1x1 + weight_norm, + bias_3x3 + bias_1x1 + bias_norm) + + def _fuse_conv_bn(self, branch): + """Fuse the parameters in a branch with a conv and bn. + + Args: + branch (mmcv.runner.Sequential): A branch with conv and bn. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The parameters obtained after + fusing the parameters of conv and bn in one branch. + The first element is the weight and the second is the bias. + """ + if branch is None: + return 0, 0 + conv_weight = branch.conv.weight + running_mean = branch.norm.running_mean + running_var = branch.norm.running_var + gamma = branch.norm.weight + beta = branch.norm.bias + eps = branch.norm.eps + + std = (running_var + eps).sqrt() + fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight + fused_bias = -running_mean * gamma / std + beta + + return fused_weight, fused_bias + + def _norm_to_conv3x3(self, branch_nrom): + """Convert a norm layer to a conv3x3-bn sequence. + + Args: + branch (nn.BatchNorm2d): A branch only with bn in the block. + + Returns: + tmp_conv3x3 (mmcv.runner.Sequential): a sequential with conv3x3 and + bn. + """ + input_dim = self.in_channels // self.groups + conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3), + dtype=branch_nrom.weight.dtype) + + for i in range(self.in_channels): + conv_weight[i, i % input_dim, 1, 1] = 1 + conv_weight = conv_weight.to(branch_nrom.weight.device) + + tmp_conv3x3 = self.create_conv_bn(kernel_size=3) + tmp_conv3x3.conv.weight.data = conv_weight + tmp_conv3x3.norm = branch_nrom + return tmp_conv3x3 + + +class MTSPPF(BaseModule): + """MTSPPF block for YOLOX-PAI RepVGG backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of pooling. Default: 5. + """ + + def __init__(self, + in_channels, + out_channels, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + kernel_size=5): + super().__init__() + hidden_features = in_channels // 2 # hidden channels + self.conv1 = ConvModule( + in_channels, + hidden_features, + 1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = ConvModule( + hidden_features * 4, + out_channels, + 1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.maxpool = nn.MaxPool2d( + kernel_size=kernel_size, stride=1, padding=kernel_size // 2) + + def forward(self, x): + x = self.conv1(x) + y1 = self.maxpool(x) + y2 = self.maxpool(y1) + return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1)) + + +@MODELS.register_module() +class RepVGG(BaseBackbone): + """RepVGG backbone. + + A PyTorch impl of : `RepVGG: Making VGG-style ConvNets Great Again + `_ + + Args: + arch (str | dict): RepVGG architecture. If use string, choose from + 'A0', 'A1`', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2', 'B2g2', + 'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'. If use dict, it should + have below keys: + + - **num_blocks** (Sequence[int]): Number of blocks in each stage. + - **width_factor** (Sequence[float]): Width deflator in each stage. + - **group_layer_map** (dict | None): RepVGG Block that declares + the need to apply group convolution. + - **se_cfg** (dict | None): SE Layer config. + - **stem_channels** (int, optional): The stem channels, the final + stem channels will be + ``min(stem_channels, base_channels*width_factor[0])``. + If not set here, 64 is used by default in the code. + + in_channels (int): Number of input image channels. Defaults to 3. + base_channels (int): Base channels of RepVGG backbone, work with + width_factor together. Defaults to 64. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(2, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Defaults to -1. + conv_cfg (dict | None): The config dict for conv layers. + Defaults to None. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + deploy (bool): Whether to switch the model structure to deployment + mode. Defaults to False. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + add_ppf (bool): Whether to use the MTSPPF block. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] + g2_layer_map = {layer: 2 for layer in groupwise_layers} + g4_layer_map = {layer: 4 for layer in groupwise_layers} + + arch_settings = { + 'A0': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[0.75, 0.75, 0.75, 2.5], + group_layer_map=None, + se_cfg=None), + 'A1': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[1, 1, 1, 2.5], + group_layer_map=None, + se_cfg=None), + 'A2': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[1.5, 1.5, 1.5, 2.75], + group_layer_map=None, + se_cfg=None), + 'B0': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[1, 1, 1, 2.5], + group_layer_map=None, + se_cfg=None, + stem_channels=64), + 'B1': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=None, + se_cfg=None), + 'B1g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B1g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=g4_layer_map, + se_cfg=None), + 'B2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=None, + se_cfg=None), + 'B2g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B2g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=g4_layer_map, + se_cfg=None), + 'B3': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=None, + se_cfg=None), + 'B3g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B3g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=g4_layer_map, + se_cfg=None), + 'D2se': + dict( + num_blocks=[8, 14, 24, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=None, + se_cfg=dict(ratio=16, divisor=1)), + 'yolox-pai-small': + dict( + num_blocks=[3, 5, 7, 3], + width_factor=[1, 1, 1, 1], + group_layer_map=None, + se_cfg=None, + stem_channels=32), + } + + def __init__(self, + arch, + in_channels=3, + base_channels=64, + out_indices=(3, ), + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + deploy=False, + norm_eval=False, + add_ppf=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(RepVGG, self).__init__(init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + assert len(arch['num_blocks']) == len( + arch['width_factor']) == len(strides) == len(dilations) + assert max(out_indices) < len(arch['num_blocks']) + if arch['group_layer_map'] is not None: + assert max(arch['group_layer_map'].keys()) <= sum( + arch['num_blocks']) + + if arch['se_cfg'] is not None: + assert isinstance(arch['se_cfg'], dict) + + self.base_channels = base_channels + self.arch = arch + self.in_channels = in_channels + self.out_indices = out_indices + self.strides = strides + self.dilations = dilations + self.deploy = deploy + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + + # defaults to 64 to prevert BC-breaking if stem_channels + # not in arch dict; + # the stem channels should not be larger than that of stage1. + channels = min( + arch.get('stem_channels', 64), + int(self.base_channels * self.arch['width_factor'][0])) + self.stem = RepVGGBlock( + self.in_channels, + channels, + stride=2, + se_cfg=arch['se_cfg'], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + deploy=deploy) + + next_create_block_idx = 1 + self.stages = [] + for i in range(len(arch['num_blocks'])): + num_blocks = self.arch['num_blocks'][i] + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = int(self.base_channels * 2**i * + self.arch['width_factor'][i]) + + stage, next_create_block_idx = self._make_stage( + channels, out_channels, num_blocks, stride, dilation, + next_create_block_idx, init_cfg) + stage_name = f'stage_{i + 1}' + self.add_module(stage_name, stage) + self.stages.append(stage_name) + + channels = out_channels + + if add_ppf: + self.ppf = MTSPPF( + out_channels, + out_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + kernel_size=5) + else: + self.ppf = nn.Identity() + + def _make_stage(self, in_channels, out_channels, num_blocks, stride, + dilation, next_create_block_idx, init_cfg): + strides = [stride] + [1] * (num_blocks - 1) + dilations = [dilation] * num_blocks + + blocks = [] + for i in range(num_blocks): + groups = self.arch['group_layer_map'].get( + next_create_block_idx, + 1) if self.arch['group_layer_map'] is not None else 1 + blocks.append( + RepVGGBlock( + in_channels, + out_channels, + stride=strides[i], + padding=dilations[i], + dilation=dilations[i], + groups=groups, + se_cfg=self.arch['se_cfg'], + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + deploy=self.deploy, + init_cfg=init_cfg)) + in_channels = out_channels + next_create_block_idx += 1 + + return Sequential(*blocks), next_create_block_idx + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, stage_name in enumerate(self.stages): + stage = getattr(self, stage_name) + x = stage(x) + if i + 1 == len(self.stages): + x = self.ppf(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + stage = getattr(self, f'stage_{i+1}') + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RepVGG, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def switch_to_deploy(self): + for m in self.modules(): + if isinstance(m, RepVGGBlock): + m.switch_to_deploy() + self.deploy = True diff --git a/mmpretrain/models/backbones/res2net.py b/mmpretrain/models/backbones/res2net.py new file mode 100644 index 0000000000000000000000000000000000000000..6e9bb6df37a2d2c9d19e613faa50ce0103aff357 --- /dev/null +++ b/mmpretrain/models/backbones/res2net.py @@ -0,0 +1,317 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import ModuleList, Sequential + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottle2neck(_Bottleneck): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + scales=4, + base_width=26, + base_channels=64, + stage_type='normal', + **kwargs): + """Bottle2neck block for Res2Net.""" + super(Bottle2neck, self).__init__(in_channels, out_channels, **kwargs) + assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.' + + mid_channels = out_channels // self.expansion + width = int(math.floor(mid_channels * (base_width / base_channels))) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width * scales, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + width * scales, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + + if stage_type == 'stage': + self.pool = nn.AvgPool2d( + kernel_size=3, stride=self.conv2_stride, padding=1) + + self.convs = ModuleList() + self.bns = ModuleList() + for i in range(scales - 1): + self.convs.append( + build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + bias=False)) + self.bns.append( + build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1]) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width * scales, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.stage_type = stage_type + self.scales = scales + self.width = width + delattr(self, 'conv2') + delattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + spx = torch.split(out, self.width, 1) + sp = self.convs[0](spx[0].contiguous()) + sp = self.relu(self.bns[0](sp)) + out = sp + for i in range(1, self.scales - 1): + if self.stage_type == 'stage': + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp.contiguous()) + sp = self.relu(self.bns[i](sp)) + out = torch.cat((out, sp), 1) + + if self.stage_type == 'normal' and self.scales != 1: + out = torch.cat((out, spx[self.scales - 1]), 1) + elif self.stage_type == 'stage' and self.scales != 1: + out = torch.cat((out, self.pool(spx[self.scales - 1])), 1) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Res2Layer(Sequential): + """Res2Layer to build Res2Net style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottle2neck. Defaults to True. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + scales (int): Scales used in Res2Net. Default: 4 + base_width (int): Basic width of each scale. Default: 26 + drop_path_rate (float or np.ndarray): stochastic depth rate. + Default: 0. + """ + + def __init__(self, + block, + in_channels, + out_channels, + num_blocks, + stride=1, + avg_down=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + scales=4, + base_width=26, + drop_path_rate=0.0, + **kwargs): + self.block = block + + if isinstance(drop_path_rate, float): + drop_path_rate = [drop_path_rate] * num_blocks + + assert len(drop_path_rate + ) == num_blocks, 'Please check the length of drop_path_rate' + + downsample = None + if stride != 1 or in_channels != out_channels: + if avg_down: + downsample = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False), + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1], + ) + else: + downsample = nn.Sequential( + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1], + ) + + layers = [] + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + scales=scales, + base_width=base_width, + stage_type='stage', + drop_path_rate=drop_path_rate[0], + **kwargs)) + in_channels = out_channels + for i in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + scales=scales, + base_width=base_width, + drop_path_rate=drop_path_rate[i], + **kwargs)) + super(Res2Layer, self).__init__(*layers) + + +@MODELS.register_module() +class Res2Net(ResNet): + """Res2Net backbone. + + A PyTorch implement of : `Res2Net: A New Multi-scale Backbone + Architecture `_ + + Args: + depth (int): Depth of Res2Net, choose from {50, 101, 152}. + scales (int): Scales used in Res2Net. Defaults to 4. + base_width (int): Basic width of each scale. Defaults to 26. + in_channels (int): Number of input image channels. Defaults to 3. + num_stages (int): Number of Res2Net stages. Defaults to 4. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + style (str): "pytorch" or "caffe". If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Defaults to "pytorch". + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to True. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottle2neck. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN', requires_grad=True)``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + + Example: + >>> from mmpretrain.models import Res2Net + >>> import torch + >>> model = Res2Net(depth=50, + ... scales=4, + ... base_width=26, + ... out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = model.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottle2neck, (3, 4, 6, 3)), + 101: (Bottle2neck, (3, 4, 23, 3)), + 152: (Bottle2neck, (3, 8, 36, 3)) + } + + def __init__(self, + scales=4, + base_width=26, + style='pytorch', + deep_stem=True, + avg_down=True, + init_cfg=None, + **kwargs): + self.scales = scales + self.base_width = base_width + super(Res2Net, self).__init__( + style=style, + deep_stem=deep_stem, + avg_down=avg_down, + init_cfg=init_cfg, + **kwargs) + + def make_res_layer(self, **kwargs): + return Res2Layer( + scales=self.scales, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/resnest.py b/mmpretrain/models/backbones/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb438f042d606946fd7b69d73568f28563e0efa --- /dev/null +++ b/mmpretrain/models/backbones/resnest.py @@ -0,0 +1,339 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN')): + super(SplitAttentionConv2d, self).__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + return getattr(self, self.norm0_name) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + groups=1, + width_per_group=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs) + + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = SplitAttentionConv2d( + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152, 200}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)), + 269: (Bottleneck, (3, 30, 48, 8)) + } + + def __init__(self, + depth, + groups=1, + width_per_group=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.width_per_group = width_per_group + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super(ResNeSt, self).__init__(depth=depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/mmpretrain/models/backbones/resnet.py b/mmpretrain/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4a254f7c2b76f03974e05194b39fbb802684873a --- /dev/null +++ b/mmpretrain/models/backbones/resnet.py @@ -0,0 +1,768 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + +eps = 1.0e-5 + + +class BasicBlock(BaseModule): + """BasicBlock for ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the output channels of conv1. This is a + reserved argument in BasicBlock and should always be 1. Default: 1. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None. + style (str): `pytorch` or `caffe`. It is unused and reserved for + unified API with Bottleneck. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + """ + + def __init__(self, + in_channels, + out_channels, + expansion=1, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + drop_path_rate=0.0, + act_cfg=dict(type='ReLU', inplace=True), + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert self.expansion == 1 + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, out_channels, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + 3, + padding=1, + bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = build_activation_layer(act_cfg) + self.downsample = downsample + self.drop_path = DropPath(drop_prob=drop_path_rate + ) if drop_path_rate > eps else nn.Identity() + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.drop_path(out) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the input/output channels of conv2. Default: 4. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None. + style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the + stride-two layer is the 3x3 conv layer, otherwise the stride-two + layer is the first 1x1 conv layer. Default: "pytorch". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + """ + + def __init__(self, + in_channels, + out_channels, + expansion=4, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU', inplace=True), + drop_path_rate=0.0, + init_cfg=None): + super(Bottleneck, self).__init__(init_cfg=init_cfg) + assert style in ['pytorch', 'caffe'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, out_channels, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = build_activation_layer(act_cfg) + self.downsample = downsample + self.drop_path = DropPath(drop_prob=drop_path_rate + ) if drop_path_rate > eps else nn.Identity() + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + @property + def norm3(self): + return getattr(self, self.norm3_name) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.drop_path(out) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def get_expansion(block, expansion=None): + """Get the expansion of a residual block. + + The block expansion will be obtained by the following order: + + 1. If ``expansion`` is given, just return it. + 2. If ``block`` has the attribute ``expansion``, then return + ``block.expansion``. + 3. Return the default value according the the block type: + 1 for ``BasicBlock`` and 4 for ``Bottleneck``. + + Args: + block (class): The block class. + expansion (int | None): The given expansion ratio. + + Returns: + int: The expansion of the block. + """ + if isinstance(expansion, int): + assert expansion > 0 + elif expansion is None: + if hasattr(block, 'expansion'): + expansion = block.expansion + elif issubclass(block, BasicBlock): + expansion = 1 + elif issubclass(block, Bottleneck): + expansion = 4 + else: + raise TypeError(f'expansion is not specified for {block.__name__}') + else: + raise TypeError('expansion must be an integer or None') + + return expansion + + +class ResLayer(nn.Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): Residual block used to build ResLayer. + num_blocks (int): Number of blocks. + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int, optional): The expansion for BasicBlock/Bottleneck. + If not specified, it will firstly be obtained via + ``block.expansion``. If the block has no attribute "expansion", + the following default values will be used: 1 for BasicBlock and + 4 for Bottleneck. Default: None. + stride (int): stride of the first block. Default: 1. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + drop_path_rate (float or list): stochastic depth rate. + Default: 0. + """ + + def __init__(self, + block, + num_blocks, + in_channels, + out_channels, + expansion=None, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + drop_path_rate=0.0, + **kwargs): + self.block = block + self.expansion = get_expansion(block, expansion) + + if isinstance(drop_path_rate, float): + drop_path_rate = [drop_path_rate] * num_blocks + + assert len(drop_path_rate + ) == num_blocks, 'Please check the length of drop_path_rate' + + downsample = None + if stride != 1 or in_channels != out_channels: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=drop_path_rate[0], + **kwargs)) + in_channels = out_channels + for i in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=drop_path_rate[i], + **kwargs)) + super(ResLayer, self).__init__(*layers) + + +@MODELS.register_module() +class ResNet(BaseBackbone): + """ResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + base_channels (int): Middle channels of the first stage. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + expansion=None, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ], + drop_path_rate=0.0): + super(ResNet, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.expansion = get_expansion(self.block, expansion) + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + _in_channels = stem_channels + _out_channels = base_channels * self.expansion + + # stochastic depth decay rule + total_depth = sum(stage_blocks) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + res_layer = self.make_res_layer( + block=self.block, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=_out_channels, + expansion=self.expansion, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=dpr[:num_blocks]) + _in_channels = _out_channels + _out_channels *= 2 + dpr = dpr[num_blocks:] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = res_layer[-1].out_channels + + def make_res_layer(self, **kwargs): + return ResLayer(**kwargs) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + if self.deep_stem: + self.stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ResNet, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress zero_init_residual if use pretrained model. + return + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + + def forward(self, x): + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer id to set the different learning rates for ResNet. + + ResNet stages: + 50 : [3, 4, 6, 3] + 101 : [3, 4, 23, 3] + 152 : [3, 8, 36, 3] + 200 : [3, 24, 36, 3] + eca269d: [3, 30, 48, 8] + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + """ + depths = self.stage_blocks + if depths[1] == 4 and depths[2] == 6: + blk2, blk3 = 2, 3 + elif depths[1] == 4 and depths[2] == 23: + blk2, blk3 = 2, 3 + elif depths[1] == 8 and depths[2] == 36: + blk2, blk3 = 4, 4 + elif depths[1] == 24 and depths[2] == 36: + blk2, blk3 = 4, 4 + elif depths[1] == 30 and depths[2] == 48: + blk2, blk3 = 5, 6 + else: + raise NotImplementedError + + N2, N3 = math.ceil(depths[1] / blk2 - + 1e-5), math.ceil(depths[2] / blk3 - 1e-5) + N = 2 + N2 + N3 # r50: 2 + 2 + 2 = 6 + max_layer_id = N + 1 # r50: 2 + 2 + 2 + 1(like head) = 7 + + if not param_name.startswith(prefix): + # For subsequent module like head + return max_layer_id, max_layer_id + 1 + + if param_name.startswith('backbone.layer'): + stage_id = int(param_name.split('.')[1][5:]) + block_id = int(param_name.split('.')[2]) + + if stage_id == 1: + layer_id = 1 + elif stage_id == 2: + layer_id = 2 + block_id // blk2 # r50: 2, 3 + elif stage_id == 3: + layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5 + else: # stage_id == 4 + layer_id = N # r50: 6 + return layer_id, max_layer_id + 1 + + else: + return 0, max_layer_id + 1 + + +@MODELS.register_module() +class ResNetV1c(ResNet): + """ResNetV1c backbone. + + This variant is described in `Bag of Tricks. + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv + in the input stem with three 3x3 convs. + """ + + def __init__(self, **kwargs): + super(ResNetV1c, self).__init__( + deep_stem=True, avg_down=False, **kwargs) + + +@MODELS.register_module() +class ResNetV1d(ResNet): + """ResNetV1d backbone. + + This variant is described in `Bag of Tricks. + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/mmpretrain/models/backbones/resnet_cifar.py b/mmpretrain/models/backbones/resnet_cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..9f17f92fd76a690ea90977b38ab2ea00345ba903 --- /dev/null +++ b/mmpretrain/models/backbones/resnet_cifar.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import ResNet + + +@MODELS.register_module() +class ResNet_CIFAR(ResNet): + """ResNet backbone for CIFAR. + + Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in + conv1, and does not apply MaxPoolinng after stem. It has been proven to + be more efficient than standard ResNet in other public codebase, e.g., + `https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`. + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + base_channels (int): Middle channels of the first stage. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): This network has specific designed stem, thus it is + asserted to be False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + def __init__(self, depth, deep_stem=False, **kwargs): + super(ResNet_CIFAR, self).__init__( + depth, deep_stem=deep_stem, **kwargs) + assert not self.deep_stem, 'ResNet_CIFAR do not support deep_stem' + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) diff --git a/mmpretrain/models/backbones/resnext.py b/mmpretrain/models/backbones/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..8858b7d3dffdcb20677e091fba4f5a1084d086a3 --- /dev/null +++ b/mmpretrain/models/backbones/resnext.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + **kwargs): + super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super(ResNeXt, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/revvit.py b/mmpretrain/models/backbones/revvit.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e6c28c943c83d0580634ac04450ee7ffc5f478 --- /dev/null +++ b/mmpretrain/models/backbones/revvit.py @@ -0,0 +1,671 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys + +import numpy as np +import torch +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from torch import nn +from torch.autograd import Function as Function + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) + + +class RevBackProp(Function): + """Custom Backpropagation function to allow (A) flushing memory in forward + and (B) activation recomputation reversibly in backward for gradient + calculation. + + Inspired by + https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py + """ + + @staticmethod + def forward( + ctx, + x, + layers, + buffer_layers, # List of layer ids for int activation to buffer + ): + """Reversible Forward pass. + + Any intermediate activations from `buffer_layers` are cached in ctx for + forward pass. This is not necessary for standard usecases. Each + reversible layer implements its own forward pass logic. + """ + buffer_layers.sort() + x1, x2 = torch.chunk(x, 2, dim=-1) + intermediate = [] + + for layer in layers: + x1, x2 = layer(x1, x2) + if layer.layer_id in buffer_layers: + intermediate.extend([x1.detach(), x2.detach()]) + + if len(buffer_layers) == 0: + all_tensors = [x1.detach(), x2.detach()] + else: + intermediate = [torch.LongTensor(buffer_layers), *intermediate] + all_tensors = [x1.detach(), x2.detach(), *intermediate] + + ctx.save_for_backward(*all_tensors) + ctx.layers = layers + + return torch.cat([x1, x2], dim=-1) + + @staticmethod + def backward(ctx, dx): + """Reversible Backward pass. + + Any intermediate activations from `buffer_layers` are recovered from + ctx. Each layer implements its own loic for backward pass (both + activation recomputation and grad calculation). + """ + d_x1, d_x2 = torch.chunk(dx, 2, dim=-1) + # retrieve params from ctx for backward + x1, x2, *int_tensors = ctx.saved_tensors + # no buffering + if len(int_tensors) != 0: + buffer_layers = int_tensors[0].tolist() + else: + buffer_layers = [] + + layers = ctx.layers + + for _, layer in enumerate(layers[::-1]): + if layer.layer_id in buffer_layers: + x1, x2, d_x1, d_x2 = layer.backward_pass( + y1=int_tensors[buffer_layers.index(layer.layer_id) * 2 + + 1], + y2=int_tensors[buffer_layers.index(layer.layer_id) * 2 + + 2], + d_y1=d_x1, + d_y2=d_x2, + ) + else: + x1, x2, d_x1, d_x2 = layer.backward_pass( + y1=x1, + y2=x2, + d_y1=d_x1, + d_y2=d_x2, + ) + + dx = torch.cat([d_x1, d_x2], dim=-1) + + del int_tensors + del d_x1, d_x2, x1, x2 + + return dx, None, None + + +class RevTransformerEncoderLayer(BaseModule): + """Reversible Transformer Encoder Layer. + + This module is a building block of Reversible Transformer Encoder, + which support backpropagation without storing activations. + The residual connection is not applied to the FFN layer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + Default: 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0 + drop_path_rate (float): stochastic depth rate. + Default 0.0 + num_fcs (int): The number of linear in FFN + Default: 2 + qkv_bias (bool): enable bias for qkv if True. + Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU') + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + layer_id (int): The layer id of current layer. Used in RevBackProp. + Default: 0 + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + layer_id: int = 0, + init_cfg=None): + super(RevTransformerEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate) + self.embed_dims = embed_dims + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + qkv_bias=qkv_bias) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + act_cfg=act_cfg, + add_identity=False) + + self.layer_id = layer_id + self.seeds = {} + + def init_weights(self): + super(RevTransformerEncoderLayer, self).init_weights() + for m in self.ffn.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def seed_cuda(self, key): + """Fix seeds to allow for stochastic elements such as dropout to be + reproduced exactly in activation recomputation in the backward pass.""" + # randomize seeds + # use cuda generator if available + if (hasattr(torch.cuda, 'default_generators') + and len(torch.cuda.default_generators) > 0): + # GPU + device_idx = torch.cuda.current_device() + seed = torch.cuda.default_generators[device_idx].seed() + else: + # CPU + seed = int(torch.seed() % sys.maxsize) + + self.seeds[key] = seed + torch.manual_seed(self.seeds[key]) + + def forward(self, x1, x2): + """ + Implementation of Reversible TransformerEncoderLayer + + ` + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + ` + """ + self.seed_cuda('attn') + # attention output + f_x2 = self.attn(self.ln1(x2)) + # apply droppath on attention output + self.seed_cuda('droppath') + f_x2_dropped = build_dropout(self.drop_path_cfg)(f_x2) + y1 = x1 + f_x2_dropped + + # free memory + if self.training: + del x1 + + # ffn output + self.seed_cuda('ffn') + g_y1 = self.ffn(self.ln2(y1)) + # apply droppath on ffn output + torch.manual_seed(self.seeds['droppath']) + g_y1_dropped = build_dropout(self.drop_path_cfg)(g_y1) + # final output + y2 = x2 + g_y1_dropped + + # free memory + if self.training: + del x2 + + return y1, y2 + + def backward_pass(self, y1, y2, d_y1, d_y2): + """Activation re-compute with the following equation. + + x2 = y2 - g(y1), g = FFN + x1 = y1 - f(x2), f = MSHA + """ + + # temporarily record intermediate activation for G + # and use them for gradient calculation of G + with torch.enable_grad(): + y1.requires_grad = True + + torch.manual_seed(self.seeds['ffn']) + g_y1 = self.ffn(self.ln2(y1)) + + torch.manual_seed(self.seeds['droppath']) + g_y1 = build_dropout(self.drop_path_cfg)(g_y1) + + g_y1.backward(d_y2, retain_graph=True) + + # activate recomputation is by design and not part of + # the computation graph in forward pass + with torch.no_grad(): + x2 = y2 - g_y1 + del g_y1 + + d_y1 = d_y1 + y1.grad + y1.grad = None + + # record F activation and calculate gradients on F + with torch.enable_grad(): + x2.requires_grad = True + + torch.manual_seed(self.seeds['attn']) + f_x2 = self.attn(self.ln1(x2)) + + torch.manual_seed(self.seeds['droppath']) + f_x2 = build_dropout(self.drop_path_cfg)(f_x2) + + f_x2.backward(d_y1, retain_graph=True) + + # propagate reverse computed activations at the + # start of the previous block + with torch.no_grad(): + x1 = y1 - f_x2 + del f_x2, y1 + + d_y2 = d_y2 + x2.grad + + x2.grad = None + x2 = x2.detach() + + return x1, x2, d_y1, d_y2 + + +class TwoStreamFusion(nn.Module): + """A general constructor for neural modules fusing two equal sized tensors + in forward. + + Args: + mode (str): The mode of fusion. Options are 'add', 'max', 'min', + 'avg', 'concat'. + """ + + def __init__(self, mode: str): + super().__init__() + self.mode = mode + + if mode == 'add': + self.fuse_fn = lambda x: torch.stack(x).sum(dim=0) + elif mode == 'max': + self.fuse_fn = lambda x: torch.stack(x).max(dim=0).values + elif mode == 'min': + self.fuse_fn = lambda x: torch.stack(x).min(dim=0).values + elif mode == 'avg': + self.fuse_fn = lambda x: torch.stack(x).mean(dim=0) + elif mode == 'concat': + self.fuse_fn = lambda x: torch.cat(x, dim=-1) + else: + raise NotImplementedError + + def forward(self, x): + # split the tensor into two halves in the channel dimension + x = torch.chunk(x, 2, dim=2) + return self.fuse_fn(x) + + +@MODELS.register_module() +class RevVisionTransformer(BaseBackbone): + """Reversible Vision Transformer. + + A PyTorch implementation of : `Reversible Vision Transformers + `_ # noqa: E501 + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"avg_featmap"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + fusion_mode (str): The fusion mode of transformer layers. + Defaults to 'concat'. + no_custom_backward (bool): Whether to use custom backward. + Defaults to False. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 768, + 'num_layers': 8, + 'num_heads': 8, + 'feedforward_channels': 768 * 3, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['h', 'huge'], + { + # The same as the implementation in MAE + # + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120 + }), + **dict.fromkeys( + ['deit-t', 'deit-tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': 192 * 4 + }), + **dict.fromkeys( + ['deit-s', 'deit-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 384 * 4 + }), + **dict.fromkeys( + ['deit-b', 'deit-base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 768 * 4 + }), + } + num_extra_tokens = 0 # The official RevViT doesn't have class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='avg_featmap', + with_cls_token=False, + frozen_stages=-1, + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + fusion_mode='concat', + no_custom_backward=False, + init_cfg=None): + super(RevVisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + self.no_custom_backward = no_custom_backward + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + layer_id=i, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(RevTransformerEncoderLayer(**_layer_cfg)) + + # fusion operation for the final output + self.fusion_layer = TwoStreamFusion(mode=fusion_mode) + + self.frozen_stages = frozen_stages + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims * 2) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + def init_weights(self): + super(RevVisionTransformer, self).init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + trunc_normal_(self.pos_embed, std=0.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers) and self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = torch.cat([x, x], dim=-1) + + # forward with different conditions + if not self.training or self.no_custom_backward: + # in eval/inference model + executing_fn = RevVisionTransformer._forward_vanilla_bp + else: + # use custom backward when self.training=True. + executing_fn = RevBackProp.apply + + x = executing_fn(x, self.layers, []) + + if self.final_norm: + x = self.ln1(x) + x = self.fusion_layer(x) + + return (self._format_output(x, patch_resolution), ) + + @staticmethod + def _forward_vanilla_bp(hidden_state, layers, buffer=[]): + """Using reversible layers without reversible backpropagation. + + Debugging purpose only. Activated with self.no_custom_backward + """ + # split into ffn state(ffn_out) and attention output(attn_out) + ffn_out, attn_out = torch.chunk(hidden_state, 2, dim=-1) + del hidden_state + + for _, layer in enumerate(layers): + attn_out, ffn_out = layer(attn_out, ffn_out) + + return torch.cat([attn_out, ffn_out], dim=-1) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) diff --git a/mmpretrain/models/backbones/riformer.py b/mmpretrain/models/backbones/riformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7cb4d37c2ac6f1479fd3c533c456f3b0a0c45e --- /dev/null +++ b/mmpretrain/models/backbones/riformer.py @@ -0,0 +1,390 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath, build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone +from .poolformer import Mlp, PatchEmbed + + +class Affine(nn.Module): + """Affine Transformation module. + + Args: + in_features (int): Input dimension. + """ + + def __init__(self, in_features): + super().__init__() + self.affine = nn.Conv2d( + in_features, + in_features, + kernel_size=1, + stride=1, + padding=0, + groups=in_features, + bias=True) + + def forward(self, x): + return self.affine(x) - x + + +class RIFormerBlock(BaseModule): + """RIFormer Block. + + Args: + dim (int): Embedding dim. + mlp_ratio (float): Mlp expansion ratio. Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='GN', num_groups=1)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-5. + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + """ + + def __init__(self, + dim, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + layer_scale_init_value=1e-5, + deploy=False): + + super().__init__() + + if deploy: + self.norm_reparam = build_norm_layer(norm_cfg, dim)[1] + else: + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = Affine(in_features=dim) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + # The following two techniques are useful to train deep RIFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.norm_cfg = norm_cfg + self.dim = dim + self.deploy = deploy + + def forward(self, x): + if hasattr(self, 'norm_reparam'): + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.norm_reparam(x)) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + else: + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + return x + + def fuse_affine(self, norm, token_mixer): + gamma_affn = token_mixer.affine.weight.reshape(-1) + gamma_affn = gamma_affn - torch.ones_like(gamma_affn) + beta_affn = token_mixer.affine.bias + gamma_ln = norm.weight + beta_ln = norm.bias + return (gamma_ln * gamma_affn), (beta_ln * gamma_affn + beta_affn) + + def get_equivalent_scale_bias(self): + eq_s, eq_b = self.fuse_affine(self.norm1, self.token_mixer) + return eq_s, eq_b + + def switch_to_deploy(self): + if self.deploy: + return + eq_s, eq_b = self.get_equivalent_scale_bias() + self.norm_reparam = build_norm_layer(self.norm_cfg, self.dim)[1] + self.norm_reparam.weight.data = eq_s + self.norm_reparam.bias.data = eq_b + self.__delattr__('norm1') + if hasattr(self, 'token_mixer'): + self.__delattr__('token_mixer') + self.deploy = True + + +def basic_blocks(dim, + index, + layers, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop_rate=.0, + drop_path_rate=0., + layer_scale_init_value=1e-5, + deploy=False): + """generate RIFormer blocks for a stage.""" + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / ( + sum(layers) - 1) + blocks.append( + RIFormerBlock( + dim, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + layer_scale_init_value=layer_scale_init_value, + deploy=deploy, + )) + blocks = nn.Sequential(*blocks) + + return blocks + + +@MODELS.register_module() +class RIFormer(BaseBackbone): + """RIFormer. + + A PyTorch implementation of RIFormer introduced by: + `RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer `_ + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``RIFormer.arch_settings``. And if dict, it + should include the following two keys: + + - layers (list[int]): Number of blocks at each stage. + - embed_dims (list[int]): The number of channels at each stage. + - mlp_ratios (list[int]): Expansion ratio of MLPs. + - layer_scale_init_value (float): Init value for Layer Scale. + + Defaults to 'S12'. + + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + in_patch_size (int): The patch size of/? input image patch embedding. + Defaults to 7. + in_stride (int): The stride of input image patch embedding. + Defaults to 4. + in_pad (int): The padding of input image patch embedding. + Defaults to 2. + down_patch_size (int): The patch size of downsampling patch embedding. + Defaults to 3. + down_stride (int): The stride of downsampling patch embedding. + Defaults to 2. + down_pad (int): The padding of downsampling patch embedding. + Defaults to 1. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which network position. + Index 0-6 respectively corresponds to + [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4] + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to -1, which means not freezing any parameters. + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict, optional): Initialization config dict + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims, --mlp_ratios: + # embedding dims and mlp ratios for the four stages + # --downsamples: flags to apply downsampling or not in four blocks + arch_settings = { + 's12': { + 'layers': [2, 2, 6, 2], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's24': { + 'layers': [4, 4, 12, 4], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm48': { + 'layers': [8, 8, 24, 8], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + } + + def __init__(self, + arch='s12', + in_channels=3, + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + in_patch_size=7, + in_stride=4, + in_pad=2, + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_rate=0., + drop_path_rate=0., + out_indices=-1, + frozen_stages=-1, + init_cfg=None, + deploy=False): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'layers' in arch and 'embed_dims' in arch, \ + f'The arch dict must have "layers" and "embed_dims", ' \ + f'but got {list(arch.keys())}.' + + layers = arch['layers'] + embed_dims = arch['embed_dims'] + mlp_ratios = arch['mlp_ratios'] \ + if 'mlp_ratios' in arch else [4, 4, 4, 4] + layer_scale_init_value = arch['layer_scale_init_value'] \ + if 'layer_scale_init_value' in arch else 1e-5 + + self.patch_embed = PatchEmbed( + patch_size=in_patch_size, + stride=in_stride, + padding=in_pad, + in_chans=in_channels, + embed_dim=embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(layers)): + stage = basic_blocks( + embed_dims[i], + i, + layers, + mlp_ratio=mlp_ratios[i], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + deploy=deploy) + network.append(stage) + if i >= len(layers) - 1: + break + if embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + padding=down_pad, + in_chans=embed_dims[i], + embed_dim=embed_dims[i + 1])) + + self.network = nn.ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 7 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + if self.out_indices: + for i_layer in self.out_indices: + layer = build_norm_layer(norm_cfg, + embed_dims[(i_layer + 1) // 2])[1] + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + self.deploy = deploy + + def forward_embeddings(self, x): + x = self.patch_embed(x) + return x + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(x) + outs.append(x_out) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + # Include both block and downsample layer. + module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RIFormer, self).train(mode) + self._freeze_stages() + return self + + def switch_to_deploy(self): + for m in self.modules(): + if isinstance(m, RIFormerBlock): + m.switch_to_deploy() + self.deploy = True diff --git a/mmpretrain/models/backbones/seresnet.py b/mmpretrain/models/backbones/seresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4437c17fa06d62f57ac18a31967a35b4f44f190f --- /dev/null +++ b/mmpretrain/models/backbones/seresnet.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.utils.checkpoint as cp + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .resnet import Bottleneck, ResLayer, ResNet + + +class SEBottleneck(Bottleneck): + """SEBottleneck block for SEResNet. + + Args: + in_channels (int): The input channels of the SEBottleneck block. + out_channels (int): The output channel of the SEBottleneck block. + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + """ + + def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs): + super(SEBottleneck, self).__init__(in_channels, out_channels, **kwargs) + self.se_layer = SELayer(out_channels, ratio=se_ratio) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + out = self.se_layer(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class SEResNet(ResNet): + """SEResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import SEResNet + >>> import torch + >>> self = SEResNet(depth=50) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 56, 56) + (1, 128, 28, 28) + (1, 256, 14, 14) + (1, 512, 7, 7) + """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, se_ratio=16, **kwargs): + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for SEResNet') + self.se_ratio = se_ratio + super(SEResNet, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer(se_ratio=self.se_ratio, **kwargs) diff --git a/mmpretrain/models/backbones/seresnext.py b/mmpretrain/models/backbones/seresnext.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2838074225930795d6d8ad70ba067b6ad4c2da --- /dev/null +++ b/mmpretrain/models/backbones/seresnext.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import ResLayer +from .seresnet import SEBottleneck as _SEBottleneck +from .seresnet import SEResNet + + +class SEBottleneck(_SEBottleneck): + """SEBottleneck block for SEResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + base_channels (int): Middle channels of the first stage. Default: 64. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + se_ratio=16, + **kwargs): + super(SEBottleneck, self).__init__(in_channels, out_channels, se_ratio, + **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # We follow the same rational of ResNext to compute mid_channels. + # For SEResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for SEResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class SEResNeXt(SEResNet): + """SEResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super(SEResNeXt, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/shufflenet_v1.py b/mmpretrain/models/backbones/shufflenet_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc3617f93b82fa5e37fa2bb5b47d93e6bd9a58f --- /dev/null +++ b/mmpretrain/models/backbones/shufflenet_v1.py @@ -0,0 +1,321 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_activation_layer +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import channel_shuffle, make_divisible +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class ShuffleUnit(BaseModule): + """ShuffleUnit block. + + ShuffleNet unit with pointwise group convolution (GConv) and channel + shuffle. + + Args: + in_channels (int): The input channels of the ShuffleUnit. + out_channels (int): The output channels of the ShuffleUnit. + groups (int): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3 + first_block (bool): Whether it is the first ShuffleUnit of a + sequential ShuffleUnits. Default: True, which means not using the + grouped 1x1 convolution. + combine (str): The ways to combine the input and output + branches. Default: 'add'. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + groups=3, + first_block=True, + combine='add', + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super(ShuffleUnit, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.first_block = first_block + self.combine = combine + self.groups = groups + self.bottleneck_channels = self.out_channels // 4 + self.with_cp = with_cp + + if self.combine == 'add': + self.depthwise_stride = 1 + self._combine_func = self._add + assert in_channels == out_channels, ( + 'in_channels must be equal to out_channels when combine ' + 'is add') + elif self.combine == 'concat': + self.depthwise_stride = 2 + self._combine_func = self._concat + self.out_channels -= self.in_channels + self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + raise ValueError(f'Cannot combine tensors with {self.combine}. ' + 'Only "add" and "concat" are supported') + + self.first_1x1_groups = 1 if first_block else self.groups + self.g_conv_1x1_compress = ConvModule( + in_channels=self.in_channels, + out_channels=self.bottleneck_channels, + kernel_size=1, + groups=self.first_1x1_groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.depthwise_conv3x3_bn = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.bottleneck_channels, + kernel_size=3, + stride=self.depthwise_stride, + padding=1, + groups=self.bottleneck_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.g_conv_1x1_expand = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.out_channels, + kernel_size=1, + groups=self.groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.act = build_activation_layer(act_cfg) + + @staticmethod + def _add(x, out): + # residual connection + return x + out + + @staticmethod + def _concat(x, out): + # concatenate along channel axis + return torch.cat((x, out), 1) + + def forward(self, x): + + def _inner_forward(x): + residual = x + + out = self.g_conv_1x1_compress(x) + out = self.depthwise_conv3x3_bn(out) + + if self.groups > 1: + out = channel_shuffle(out, self.groups) + + out = self.g_conv_1x1_expand(out) + + if self.combine == 'concat': + residual = self.avgpool(residual) + out = self.act(out) + out = self._combine_func(residual, out) + else: + out = self._combine_func(residual, out) + out = self.act(out) + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class ShuffleNetV1(BaseBackbone): + """ShuffleNetV1 backbone. + + Args: + groups (int): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3. + widen_factor (float): Width multiplier - adjusts the number + of channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (2, ) + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + groups=3, + widen_factor=1.0, + out_indices=(2, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=None): + super(ShuffleNetV1, self).__init__(init_cfg) + self.init_cfg = init_cfg + self.stage_blocks = [4, 8, 4] + self.groups = groups + + for index in out_indices: + if index not in range(0, 3): + raise ValueError('the item in out_indices must in ' + f'range(0, 3). But received {index}') + + if frozen_stages not in range(-1, 3): + raise ValueError('frozen_stages must be in range(-1, 3). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if groups == 1: + channels = (144, 288, 576) + elif groups == 2: + channels = (200, 400, 800) + elif groups == 3: + channels = (240, 480, 960) + elif groups == 4: + channels = (272, 544, 1088) + elif groups == 8: + channels = (384, 768, 1536) + else: + raise ValueError(f'{groups} groups is not supported for 1x1 ' + 'Grouped Convolutions') + + channels = [make_divisible(ch * widen_factor, 8) for ch in channels] + + self.in_channels = int(24 * widen_factor) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + first_block = True if i == 0 else False + layer = self.make_layer(channels[i], num_blocks, first_block) + self.layers.append(layer) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + layer = self.layers[i] + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ShuffleNetV1, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'conv1' in name: + normal_init(m, mean=0, std=0.01) + else: + normal_init(m, mean=0, std=1.0 / m.weight.shape[1]) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, val=1, bias=0.0001) + if isinstance(m, _BatchNorm): + if m.running_mean is not None: + nn.init.constant_(m.running_mean, 0) + + def make_layer(self, out_channels, num_blocks, first_block=False): + """Stack ShuffleUnit blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): Number of blocks. + first_block (bool): Whether is the first ShuffleUnit of a + sequential ShuffleUnits. Default: False, which means using + the grouped 1x1 convolution. + """ + layers = [] + for i in range(num_blocks): + first_block = first_block if i == 0 else False + combine_mode = 'concat' if i == 0 else 'add' + layers.append( + ShuffleUnit( + self.in_channels, + out_channels, + groups=self.groups, + first_block=first_block, + combine=combine_mode, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def train(self, mode=True): + super(ShuffleNetV1, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/shufflenet_v2.py b/mmpretrain/models/backbones/shufflenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..02f9c749a814b0b4ee4e04dd6afacda078ae6f39 --- /dev/null +++ b/mmpretrain/models/backbones/shufflenet_v2.py @@ -0,0 +1,305 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import channel_shuffle +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class InvertedResidual(BaseModule): + """InvertedResidual block for ShuffleNetV2 backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 convolution layer. Default: 1 + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.stride = stride + self.with_cp = with_cp + + branch_features = out_channels // 2 + if self.stride == 1: + assert in_channels == branch_features * 2, ( + f'in_channels ({in_channels}) should equal to ' + f'branch_features * 2 ({branch_features * 2}) ' + 'when stride is 1') + + if in_channels != branch_features * 2: + assert self.stride != 1, ( + f'stride ({self.stride}) should not equal 1 when ' + f'in_channels != branch_features * 2') + + if self.stride > 1: + self.branch1 = nn.Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=self.stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + + self.branch2 = nn.Sequential( + ConvModule( + in_channels if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + groups=branch_features, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + + def _inner_forward(x): + if self.stride > 1: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + else: + # Channel Split operation. using these lines of code to replace + # ``chunk(x, 2, dim=1)`` can make it easier to deploy a + # shufflenetv2 model by using mmdeploy. + channels = x.shape[1] + c = channels // 2 + channels % 2 + x1 = x[:, :c, :, :] + x2 = x[:, c:, :, :] + + out = torch.cat((x1, self.branch2(x2)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class ShuffleNetV2(BaseBackbone): + """ShuffleNetV2 backbone. + + Args: + widen_factor (float): Width multiplier - adjusts the number of + channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + widen_factor=1.0, + out_indices=(3, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=None): + super(ShuffleNetV2, self).__init__(init_cfg) + self.stage_blocks = [4, 8, 4] + for index in out_indices: + if index not in range(0, 4): + raise ValueError('the item in out_indices must in ' + f'range(0, 4). But received {index}') + + if frozen_stages not in range(-1, 4): + raise ValueError('frozen_stages must be in range(-1, 4). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if widen_factor == 0.5: + channels = [48, 96, 192, 1024] + elif widen_factor == 1.0: + channels = [116, 232, 464, 1024] + elif widen_factor == 1.5: + channels = [176, 352, 704, 1024] + elif widen_factor == 2.0: + channels = [244, 488, 976, 2048] + else: + raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. ' + f'But received {widen_factor}') + + self.in_channels = 24 + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + layer = self._make_layer(channels[i], num_blocks) + self.layers.append(layer) + + output_channels = channels[-1] + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=output_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def _make_layer(self, out_channels, num_blocks): + """Stack blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): number of blocks. + """ + layers = [] + for i in range(num_blocks): + stride = 2 if i == 0 else 1 + layers.append( + InvertedResidual( + in_channels=self.in_channels, + out_channels=out_channels, + stride=stride, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ShuffleNetV2, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'conv1' in name: + normal_init(m, mean=0, std=0.01) + else: + normal_init(m, mean=0, std=1.0 / m.weight.shape[1]) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m.weight, val=1, bias=0.0001) + if isinstance(m, _BatchNorm): + if m.running_mean is not None: + nn.init.constant_(m.running_mean, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def train(self, mode=True): + super(ShuffleNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpretrain/models/backbones/sparse_convnext.py b/mmpretrain/models/backbones/sparse_convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..8f361360af460746a0f70206becb519252135596 --- /dev/null +++ b/mmpretrain/models/backbones/sparse_convnext.py @@ -0,0 +1,298 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmengine.model import ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import (SparseAvgPooling, SparseConv2d, SparseHelper, + SparseMaxPooling, build_norm_layer) +from .convnext import ConvNeXt, ConvNeXtBlock + + +class SparseConvNeXtBlock(ConvNeXtBlock): + """Sparse ConvNeXt Block. + + Note: + There are two equivalent implementations: + 1. DwConv -> SparseLayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv; + all outputs are in (N, C, H, W). + 2. DwConv -> SparseLayerNorm -> Permute to (N, H, W, C) -> Linear -> + GELU -> Linear; Permute back + As default, we use the second to align with the official repository. + And it may be slightly faster. + """ + + def forward(self, x): + + def _inner_forward(x): + shortcut = x + x = self.depthwise_conv(x) + + if self.linear_pw_conv: + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x, data_format='channel_last') + x = self.pointwise_conv1(x) + x = self.act(x) + if self.grn is not None: + x = self.grn(x, data_format='channel_last') + x = self.pointwise_conv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + else: + x = self.norm(x, data_format='channel_first') + x = self.pointwise_conv1(x) + x = self.act(x) + + if self.grn is not None: + x = self.grn(x, data_format='channel_first') + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = x.mul(self.gamma.view(1, -1, 1, 1)) + + x *= SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=True) + + x = shortcut + self.drop_path(x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class SparseConvNeXt(ConvNeXt): + """ConvNeXt with sparse module conversion function. + + Modified from + https://github.com/keyu-tian/SparK/blob/main/models/convnext.py + and + https://github.com/keyu-tian/SparK/blob/main/encoder.py + To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvNeXt.arch_settings``. And if dict, it + should include the following two keys: + - depths (list[int]): Number of blocks at each stage. + - channels (list[int]): The number of channels at each stage. + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + stem_patch_size (int): The size of one patch in the stem layer. + Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='SparseLN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to True. + use_grn (bool): Whether to add Global Response Normalization in the + blocks. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_output (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): Initialization config dict. + """ # noqa: E501 + + def __init__(self, + arch: str = 'small', + in_channels: int = 3, + stem_patch_size: int = 4, + norm_cfg: dict = dict(type='SparseLN2d', eps=1e-6), + act_cfg: dict = dict(type='GELU'), + linear_pw_conv: bool = True, + use_grn: bool = False, + drop_path_rate: float = 0, + layer_scale_init_value: float = 1e-6, + out_indices: int = -1, + frozen_stages: int = 0, + gap_before_output: bool = True, + with_cp: bool = False, + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict( + type='TruncNormal', + layer=['Conv2d', 'Linear'], + std=.02, + bias=0.), + dict( + type='Constant', layer=['LayerNorm'], val=1., + bias=0.), + ]): + super(ConvNeXt, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'depths' in arch and 'channels' in arch, \ + f'The arch dict must have "depths" and "channels", ' \ + f'but got {list(arch.keys())}.' + + self.depths = arch['depths'] + self.channels = arch['channels'] + assert (isinstance(self.depths, Sequence) + and isinstance(self.channels, Sequence) + and len(self.depths) == len(self.channels)), \ + f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \ + 'should be both sequence with the same length.' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_output = gap_before_output + + # 4 downsample layers between stages, including the stem layer. + self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.channels[0], + kernel_size=stem_patch_size, + stride=stem_patch_size), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + block_idx = 0 + + # 4 feature resolution stages, each consisting of multiple residual + # blocks + self.stages = nn.ModuleList() + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2), + ) + self.downsample_layers.append(downsample_layer) + + stage = Sequential(*[ + SparseConvNeXtBlock( + in_channels=channels, + drop_path_rate=dpr[block_idx + j], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + layer_scale_init_value=layer_scale_init_value, + use_grn=use_grn, + with_cp=with_cp) for j in range(depth) + ]) + block_idx += depth + + self.stages.append(stage) + + self.dense_model_to_sparse(m=self) + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if i in self.out_indices: + if self.gap_before_output: + gap = x.mean([-2, -1], keepdim=True) + outs.append(gap.flatten(1)) + else: + outs.append(x) + + return tuple(outs) + + def dense_model_to_sparse(self, m: nn.Module) -> nn.Module: + """Convert regular dense modules to sparse modules.""" + output = m + if isinstance(m, nn.Conv2d): + m: nn.Conv2d + bias = m.bias is not None + output = SparseConv2d( + m.in_channels, + m.out_channels, + kernel_size=m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=bias, + padding_mode=m.padding_mode, + ) + output.weight.data.copy_(m.weight.data) + if bias: + output.bias.data.copy_(m.bias.data) + + elif isinstance(m, nn.MaxPool2d): + m: nn.MaxPool2d + output = SparseMaxPooling( + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + return_indices=m.return_indices, + ceil_mode=m.ceil_mode) + + elif isinstance(m, nn.AvgPool2d): + m: nn.AvgPool2d + output = SparseAvgPooling( + m.kernel_size, + m.stride, + m.padding, + ceil_mode=m.ceil_mode, + count_include_pad=m.count_include_pad, + divisor_override=m.divisor_override) + + # elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + # m: nn.BatchNorm2d + # output = (SparseSyncBatchNorm2d + # if enable_sync_bn else SparseBatchNorm2d)( + # m.weight.shape[0], + # eps=m.eps, + # momentum=m.momentum, + # affine=m.affine, + # track_running_stats=m.track_running_stats) + # output.weight.data.copy_(m.weight.data) + # output.bias.data.copy_(m.bias.data) + # output.running_mean.data.copy_(m.running_mean.data) + # output.running_var.data.copy_(m.running_var.data) + # output.num_batches_tracked.data.copy_(m.num_batches_tracked.data) + + for name, child in m.named_children(): + output.add_module(name, self.dense_model_to_sparse(child)) + del m + return output diff --git a/mmpretrain/models/backbones/sparse_resnet.py b/mmpretrain/models/backbones/sparse_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..67597f1f0327f466a6841333c8247f96238ce35f --- /dev/null +++ b/mmpretrain/models/backbones/sparse_resnet.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import Optional, Tuple + +import torch.nn as nn + +from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling, + SparseBatchNorm2d, + SparseConv2d, + SparseMaxPooling, + SparseSyncBatchNorm2d) +from mmpretrain.registry import MODELS +from .resnet import ResNet + + +@MODELS.register_module() +class SparseResNet(ResNet): + """ResNet with sparse module conversion function. + + Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Defaults to 3. + stem_channels (int): Output channels of the stem layer. Defaults to 64. + base_channels (int): Middle channels of the first stage. + Defaults to 64. + num_stages (int): Stages of the network. Defaults to 4. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + conv_cfg (dict | None): The config dict for conv layers. + Defaults to None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to True. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + """ + + def __init__(self, + depth: int, + in_channels: int = 3, + stem_channels: int = 64, + base_channels: int = 64, + expansion: Optional[int] = None, + num_stages: int = 4, + strides: Tuple[int] = (1, 2, 2, 2), + dilations: Tuple[int] = (1, 1, 1, 1), + out_indices: Tuple[int] = (3, ), + style: str = 'pytorch', + deep_stem: bool = False, + avg_down: bool = False, + frozen_stages: int = -1, + conv_cfg: Optional[dict] = None, + norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'), + norm_eval: bool = False, + with_cp: bool = False, + zero_init_residual: bool = False, + init_cfg: Optional[dict] = [ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ], + drop_path_rate: float = 0, + **kwargs): + super().__init__( + depth=depth, + in_channels=in_channels, + stem_channels=stem_channels, + base_channels=base_channels, + expansion=expansion, + num_stages=num_stages, + strides=strides, + dilations=dilations, + out_indices=out_indices, + style=style, + deep_stem=deep_stem, + avg_down=avg_down, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + norm_eval=norm_eval, + with_cp=with_cp, + zero_init_residual=zero_init_residual, + init_cfg=init_cfg, + drop_path_rate=drop_path_rate, + **kwargs) + norm_type = norm_cfg['type'] + enable_sync_bn = False + if re.search('Sync', norm_type) is not None: + enable_sync_bn = True + self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn) + + def dense_model_to_sparse(self, m: nn.Module, + enable_sync_bn: bool) -> nn.Module: + """Convert regular dense modules to sparse modules.""" + output = m + if isinstance(m, nn.Conv2d): + m: nn.Conv2d + bias = m.bias is not None + output = SparseConv2d( + m.in_channels, + m.out_channels, + kernel_size=m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=bias, + padding_mode=m.padding_mode, + ) + output.weight.data.copy_(m.weight.data) + if bias: + output.bias.data.copy_(m.bias.data) + + elif isinstance(m, nn.MaxPool2d): + m: nn.MaxPool2d + output = SparseMaxPooling( + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + return_indices=m.return_indices, + ceil_mode=m.ceil_mode) + + elif isinstance(m, nn.AvgPool2d): + m: nn.AvgPool2d + output = SparseAvgPooling( + m.kernel_size, + m.stride, + m.padding, + ceil_mode=m.ceil_mode, + count_include_pad=m.count_include_pad, + divisor_override=m.divisor_override) + + elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + m: nn.BatchNorm2d + output = (SparseSyncBatchNorm2d + if enable_sync_bn else SparseBatchNorm2d)( + m.weight.shape[0], + eps=m.eps, + momentum=m.momentum, + affine=m.affine, + track_running_stats=m.track_running_stats) + output.weight.data.copy_(m.weight.data) + output.bias.data.copy_(m.bias.data) + output.running_mean.data.copy_(m.running_mean.data) + output.running_var.data.copy_(m.running_var.data) + output.num_batches_tracked.data.copy_(m.num_batches_tracked.data) + + elif isinstance(m, (nn.Conv1d, )): + raise NotImplementedError + + for name, child in m.named_children(): + output.add_module( + name, + self.dense_model_to_sparse( + child, enable_sync_bn=enable_sync_bn)) + del m + return output diff --git a/mmpretrain/models/backbones/swin_transformer.py b/mmpretrain/models/backbones/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..559fd5e9150f78a9801fcb9070e114b4e96113c5 --- /dev/null +++ b/mmpretrain/models/backbones/swin_transformer.py @@ -0,0 +1,585 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import (ShiftWindowMSA, resize_pos_embed, + resize_relative_position_bias_table, to_2tuple) +from .base_backbone import BaseBackbone + + +class SwinBlock(BaseModule): + """Swin Transformer block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + shift (bool): Shift the attention window or not. Defaults to False. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + shift=False, + ffn_ratio=4., + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(SwinBlock, self).__init__(init_cfg) + self.with_cp = with_cp + + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': window_size // 2 if shift else 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'pad_small_map': pad_small_map, + **attn_cfgs + } + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA(**_attn_cfgs) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockSequence(BaseModule): + """Module with successive Swin Transformer blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive swin transformer blocks. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=7, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + self.embed_dims = embed_dims + self.blocks = ModuleList() + for i in range(depth): + _block_cfg = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'window_size': window_size, + 'shift': False if i % 2 == 0 else True, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **block_cfgs[i] + } + block = SwinBlock(**_block_cfg) + self.blocks.append(block) + + if downsample: + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': 2 * embed_dims, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = PatchMerging(**_downsample_cfg) + else: + self.downsample = None + + def forward(self, x, in_shape, do_downsample=True): + for block in self.blocks: + x = block(x, in_shape) + + if self.downsample is not None and do_downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + return x, out_shape + + @property + def out_channels(self): + if self.downsample: + return self.downsample.out_channels + else: + return self.embed_dims + + +@MODELS.register_module() +class SwinTransformer(BaseBackbone): + """Swin Transformer. + + A PyTorch implement of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows + `_ + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + + Defaults to 'tiny'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int): The height and width of the window. Defaults to 7. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_after_downsample (bool): Whether to output the feature map of a + stage after the following downsample layer. Defaults to False. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embeding vector resize. Defaults to "bicubic". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import SwinTransformer + >>> import torch + >>> extra_config = dict( + >>> arch='tiny', + >>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3, + >>> 'expansion_ratio': 3})) + >>> self = SwinTransformer(**extra_config) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> output = self.forward(inputs) + >>> print(output.shape) + (1, 2592, 4) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 96, + 'depths': [2, 2, 6, 2], + 'num_heads': [3, 6, 12, 24]}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 96, + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48]}), + } # yapf: disable + + _version = 3 + num_extra_tokens = 0 + + def __init__(self, + arch='tiny', + img_size=224, + patch_size=4, + in_channels=3, + window_size=7, + drop_rate=0., + drop_path_rate=0.1, + out_indices=(3, ), + out_after_downsample=False, + use_abs_pos_embed=False, + interpolate_mode='bicubic', + with_cp=False, + frozen_stages=-1, + norm_eval=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + patch_cfg=dict(), + init_cfg=None): + super(SwinTransformer, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.out_after_downsample = out_after_downsample + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + self.frozen_stages = frozen_stages + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook( + self._prepare_abs_pos_embed) + + self._register_load_state_dict_pre_hook( + self._prepare_relative_position_bias_table) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.norm_eval = norm_eval + + # stochastic depth + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i < self.num_layers - 1 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': window_size, + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **stage_cfg + } + + stage = SwinBlockSequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + if self.out_after_downsample: + self.num_features = embed_dims[1:] + else: + self.num_features = embed_dims[:-1] + + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, + self.num_features[i])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self): + super(SwinTransformer, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage( + x, hw_shape, do_downsample=self.out_after_downsample) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + if stage.downsample is not None and not self.out_after_downsample: + x, hw_shape = stage.downsample(x, hw_shape) + + return tuple(outs) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args, + **kwargs): + """load checkpoints.""" + # Names of some parameters in has been changed. + version = local_metadata.get('version', None) + if (version is None + or version < 2) and self.__class__ is SwinTransformer: + final_stage_num = len(self.stages) - 1 + state_dict_keys = list(state_dict.keys()) + for k in state_dict_keys: + if k.startswith('norm.') or k.startswith('backbone.norm.'): + convert_key = k.replace('norm.', f'norm{final_stage_num}.') + state_dict[convert_key] = state_dict[k] + del state_dict[k] + if (version is None + or version < 3) and self.__class__ is SwinTransformer: + state_dict_keys = list(state_dict.keys()) + for k in state_dict_keys: + if 'attn_mask' in k: + del state_dict[k] + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + *args, **kwargs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(SwinTransformer, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'absolute_pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.absolute_pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Resize the absolute_pos_embed shape from ' + f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def _prepare_relative_position_bias_table(self, state_dict, prefix, *args, + **kwargs): + state_dict_model = self.state_dict() + all_keys = list(state_dict_model.keys()) + for key in all_keys: + if 'relative_position_bias_table' in key: + ckpt_key = prefix + key + if ckpt_key not in state_dict: + continue + relative_position_bias_table_pretrained = state_dict[ckpt_key] + relative_position_bias_table_current = state_dict_model[key] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + if L1 != L2: + src_size = int(L1**0.5) + dst_size = int(L2**0.5) + new_rel_pos_bias = resize_relative_position_bias_table( + src_size, dst_size, + relative_position_bias_table_pretrained, nH1) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info('Resize the relative_position_bias_table from ' + f'{state_dict[ckpt_key].shape} to ' + f'{new_rel_pos_bias.shape}') + state_dict[ckpt_key] = new_rel_pos_bias + + # The index buffer need to be re-generated. + index_buffer = ckpt_key.replace('bias_table', 'index') + del state_dict[index_buffer] + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = sum(self.depths) + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('stages'): + stage_id = int(param_name.split('.')[1]) + block_id = param_name.split('.')[3] + if block_id in ('reduction', 'norm'): + layer_depth = sum(self.depths[:stage_id + 1]) + else: + layer_depth = sum(self.depths[:stage_id]) + int(block_id) + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/swin_transformer_v2.py b/mmpretrain/models/backbones/swin_transformer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..142505a808ae3fc631d54e1a56ae483db242da31 --- /dev/null +++ b/mmpretrain/models/backbones/swin_transformer_v2.py @@ -0,0 +1,567 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from ..builder import MODELS +from ..utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2, + resize_pos_embed, to_2tuple) +from .base_backbone import BaseBackbone + + +class SwinBlockV2(BaseModule): + """Swin Transformer V2 block. Use post normalization. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + shift (bool): Shift the attention window or not. Defaults to False. + extra_norm (bool): Whether add extra norm at the end of main branch. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pretrained_window_size (int): Window size in pretrained. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=8, + shift=False, + extra_norm=False, + ffn_ratio=4., + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + pretrained_window_size=0, + init_cfg=None): + + super(SwinBlockV2, self).__init__(init_cfg) + self.with_cp = with_cp + self.extra_norm = extra_norm + + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': window_size // 2 if shift else 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'pad_small_map': pad_small_map, + **attn_cfgs + } + # use V2 attention implementation + _attn_cfgs.update( + window_msa=WindowMSAV2, + pretrained_window_size=to_2tuple(pretrained_window_size)) + self.attn = ShiftWindowMSA(**_attn_cfgs) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + 'add_identity': False, + **ffn_cfgs + } + self.ffn = FFN(**_ffn_cfgs) + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + # add extra norm for every n blocks in huge and giant model + if self.extra_norm: + self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, x, hw_shape): + + def _inner_forward(x): + # Use post normalization + identity = x + x = self.attn(x, hw_shape) + x = self.norm1(x) + x = x + identity + + identity = x + x = self.ffn(x) + x = self.norm2(x) + x = x + identity + + if self.extra_norm: + x = self.norm3(x) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockV2Sequence(BaseModule): + """Module with successive Swin Transformer blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive swin transformer blocks. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + extra_norm_every_n_blocks (int): Add extra norm at the end of main + branch every n blocks. Defaults to 0, which means no needs for + extra norm layer. + pretrained_window_size (int): Window size in pretrained. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=8, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + extra_norm_every_n_blocks=0, + pretrained_window_size=0, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + if downsample: + self.out_channels = 2 * embed_dims + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': self.out_channels, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = PatchMerging(**_downsample_cfg) + else: + self.out_channels = embed_dims + self.downsample = None + + self.blocks = ModuleList() + for i in range(depth): + extra_norm = True if extra_norm_every_n_blocks and \ + (i + 1) % extra_norm_every_n_blocks == 0 else False + _block_cfg = { + 'embed_dims': self.out_channels, + 'num_heads': num_heads, + 'window_size': window_size, + 'shift': False if i % 2 == 0 else True, + 'extra_norm': extra_norm, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + 'pretrained_window_size': pretrained_window_size, + **block_cfgs[i] + } + block = SwinBlockV2(**_block_cfg) + self.blocks.append(block) + + def forward(self, x, in_shape): + if self.downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + + for block in self.blocks: + x = block(x, out_shape) + + return x, out_shape + + +@MODELS.register_module() +class SwinTransformerV2(BaseBackbone): + """Swin Transformer V2. + + A PyTorch implement of : `Swin Transformer V2: + Scaling Up Capacity and Resolution + `_ + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + - **extra_norm_every_n_blocks** (int): Add extra norm at the end + of main branch every n blocks. + + Defaults to 'tiny'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int | Sequence): The height and width of the window. + Defaults to 7. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embeding vector resize. Defaults to "bicubic". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + pretrained_window_sizes (tuple(int)): Pretrained window sizes of + each layer. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import SwinTransformerV2 + >>> import torch + >>> extra_config = dict( + >>> arch='tiny', + >>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3, + >>> 'padding': 'same'})) + >>> self = SwinTransformerV2(**extra_config) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> output = self.forward(inputs) + >>> print(output.shape) + (1, 2592, 4) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 96, + 'depths': [2, 2, 6, 2], + 'num_heads': [3, 6, 12, 24], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 96, + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48], + 'extra_norm_every_n_blocks': 0}), + # head count not certain for huge, and is employed for another + # parallel study about self-supervised learning. + **dict.fromkeys(['h', 'huge'], + {'embed_dims': 352, + 'depths': [2, 2, 18, 2], + 'num_heads': [8, 16, 32, 64], + 'extra_norm_every_n_blocks': 6}), + **dict.fromkeys(['g', 'giant'], + {'embed_dims': 512, + 'depths': [2, 2, 42, 4], + 'num_heads': [16, 32, 64, 128], + 'extra_norm_every_n_blocks': 6}), + } # yapf: disable + + _version = 1 + num_extra_tokens = 0 + + def __init__(self, + arch='tiny', + img_size=256, + patch_size=4, + in_channels=3, + window_size=8, + drop_rate=0., + drop_path_rate=0.1, + out_indices=(3, ), + use_abs_pos_embed=False, + interpolate_mode='bicubic', + with_cp=False, + frozen_stages=-1, + norm_eval=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + patch_cfg=dict(), + pretrained_window_sizes=[0, 0, 0, 0], + init_cfg=None): + super(SwinTransformerV2, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'depths', 'num_heads', + 'extra_norm_every_n_blocks' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.extra_norm_every_n_blocks = self.arch_settings[ + 'extra_norm_every_n_blocks'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + self.frozen_stages = frozen_stages + + if isinstance(window_size, int): + self.window_sizes = [window_size for _ in range(self.num_layers)] + elif isinstance(window_size, Sequence): + assert len(window_size) == self.num_layers, \ + f'Length of window_sizes {len(window_size)} is not equal to '\ + f'length of stages {self.num_layers}.' + self.window_sizes = window_size + else: + raise TypeError('window_size should be a Sequence or int.') + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook( + self._prepare_abs_pos_embed) + + self._register_load_state_dict_pre_hook(self._delete_reinit_params) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.norm_eval = norm_eval + + # stochastic depth + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i > 0 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': self.window_sizes[i], + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + 'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks, + 'pretrained_window_size': pretrained_window_sizes[i], + 'downsample_cfg': dict(use_post_norm=True), + **stage_cfg + } + + stage = SwinBlockV2Sequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self): + super(SwinTransformerV2, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + stage.out_channels).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(SwinTransformerV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'absolute_pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.absolute_pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Resize the absolute_pos_embed shape from ' + f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs): + # delete relative_position_index since we always re-init it + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Delete `relative_position_index` and `relative_coords_table` ' + 'since we always re-init these params according to the ' + '`window_size`, which might cause unwanted but unworried ' + 'warnings when loading checkpoint.') + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete relative_coords_table since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_coords_table' in k + ] + for k in relative_position_index_keys: + del state_dict[k] diff --git a/mmpretrain/models/backbones/t2t_vit.py b/mmpretrain/models/backbones/t2t_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..a57b95e1fb00b227c400e7b32fa612e3539503c6 --- /dev/null +++ b/mmpretrain/models/backbones/t2t_vit.py @@ -0,0 +1,447 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) +from .base_backbone import BaseBackbone + + +class T2TTransformerLayer(BaseModule): + """Transformer Layer for T2T_ViT. + + Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports + different ``input_dims`` and ``embed_dims``. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs + input_dims (int, optional): The input token dimension. + Defaults to None. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + + Notes: + In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e. + ``(embed_dims // num_heads) ** -0.5``. However, in the official + code, it uses ``(input_dims // num_heads) ** -0.5``, so here we + keep the same with the official implementation. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + input_dims=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=False, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg) + + self.v_shortcut = True if input_dims is not None else False + input_dims = input_dims or embed_dims + + self.ln1 = build_norm_layer(norm_cfg, input_dims) + + self.attn = MultiheadAttention( + input_dims=input_dims, + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + qk_scale=qk_scale or (input_dims // num_heads)**-0.5, + v_shortcut=self.v_shortcut) + + self.ln2 = build_norm_layer(norm_cfg, embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + def forward(self, x): + if self.v_shortcut: + x = self.attn(self.ln1(x)) + else: + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + return x + + +class T2TModule(BaseModule): + """Tokens-to-Token module. + + "Tokens-to-Token module" (T2T Module) can model the local structure + information of images and reduce the length of tokens progressively. + + Args: + img_size (int): Input image size + in_channels (int): Number of input channels + embed_dims (int): Embedding dimension + token_dims (int): Tokens dimension in T2TModuleAttention. + use_performer (bool): If True, use Performer version self-attention to + adopt regular self-attention. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Default: None. + + Notes: + Usually, ``token_dim`` is set as a small value (32 or 64) to reduce + MACs + """ + + def __init__( + self, + img_size=224, + in_channels=3, + embed_dims=384, + token_dims=64, + use_performer=False, + init_cfg=None, + ): + super(T2TModule, self).__init__(init_cfg) + + self.embed_dims = embed_dims + + self.soft_split0 = nn.Unfold( + kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) + self.soft_split1 = nn.Unfold( + kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.soft_split2 = nn.Unfold( + kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + + if not use_performer: + self.attention1 = T2TTransformerLayer( + input_dims=in_channels * 7 * 7, + embed_dims=token_dims, + num_heads=1, + feedforward_channels=token_dims) + + self.attention2 = T2TTransformerLayer( + input_dims=token_dims * 3 * 3, + embed_dims=token_dims, + num_heads=1, + feedforward_channels=token_dims) + + self.project = nn.Linear(token_dims * 3 * 3, embed_dims) + else: + raise NotImplementedError("Performer hasn't been implemented.") + + # there are 3 soft split, stride are 4,2,2 separately + out_side = img_size // (4 * 2 * 2) + self.init_out_size = [out_side, out_side] + self.num_patches = out_side**2 + + @staticmethod + def _get_unfold_size(unfold: nn.Unfold, input_size): + h, w = input_size + kernel_size = to_2tuple(unfold.kernel_size) + stride = to_2tuple(unfold.stride) + padding = to_2tuple(unfold.padding) + dilation = to_2tuple(unfold.dilation) + + h_out = (h + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (w + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + return (h_out, w_out) + + def forward(self, x): + # step0: soft split + hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:]) + x = self.soft_split0(x).transpose(1, 2) + + for step in [1, 2]: + # re-structurization/reconstruction + attn = getattr(self, f'attention{step}') + x = attn(x).transpose(1, 2) + B, C, _ = x.shape + x = x.reshape(B, C, hw_shape[0], hw_shape[1]) + + # soft split + soft_split = getattr(self, f'soft_split{step}') + hw_shape = self._get_unfold_size(soft_split, hw_shape) + x = soft_split(x).transpose(1, 2) + + # final tokens + x = self.project(x) + return x, hw_shape + + +def get_sinusoid_encoding(n_position, embed_dims): + """Generate sinusoid encoding table. + + Sinusoid encoding is a kind of relative position encoding method came from + `Attention Is All You Need`_. + + Args: + n_position (int): The length of the input token. + embed_dims (int): The position embedding dimension. + + Returns: + :obj:`torch.FloatTensor`: The sinusoid encoding table. + """ + + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (i // 2) / embed_dims) + for i in range(embed_dims) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos) for pos in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +@MODELS.register_module() +class T2T_ViT(BaseBackbone): + """Tokens-to-Token Vision Transformer (T2T-ViT) + + A PyTorch implementation of `Tokens-to-Token ViT: Training Vision + Transformers from Scratch on ImageNet `_ + + Args: + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + in_channels (int): Number of input channels. + embed_dims (int): Embedding dimension. + num_layers (int): Num of transformer layers in encoder. + Defaults to 14. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Dropout rate after position embedding. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. Defaults to + ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + t2t_cfg (dict): Extra config of Tokens-to-Token module. + Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + img_size=224, + in_channels=3, + embed_dims=384, + num_layers=14, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + final_norm=True, + out_type='cls_token', + with_cls_token=True, + interpolate_mode='bicubic', + t2t_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super().__init__(init_cfg) + + # Token-to-Token Module + self.tokens_to_token = T2TModule( + img_size=img_size, + in_channels=in_channels, + embed_dims=embed_dims, + **t2t_cfg) + self.patch_resolution = self.tokens_to_token.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + sinusoid_table = get_sinusoid_encoding( + num_patches + self.num_extra_tokens, embed_dims) + self.register_buffer('pos_embed', sinusoid_table) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must be a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = num_layers + index + assert 0 <= out_indices[i] <= num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)] + + self.encoder = ModuleList() + for i in range(num_layers): + if isinstance(layer_cfgs, Sequence): + layer_cfg = layer_cfgs[i] + else: + layer_cfg = deepcopy(layer_cfgs) + layer_cfg = { + 'embed_dims': embed_dims, + 'num_heads': 6, + 'feedforward_channels': 3 * embed_dims, + 'drop_path_rate': dpr[i], + 'qkv_bias': False, + 'norm_cfg': norm_cfg, + **layer_cfg + } + + layer = T2TTransformerLayer(**layer_cfg) + self.encoder.append(layer) + + self.final_norm = final_norm + if final_norm: + self.norm = build_norm_layer(norm_cfg, embed_dims) + else: + self.norm = nn.Identity() + + def init_weights(self): + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress custom init if use pretrained model. + return + + trunc_normal_(self.cls_token, std=.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.tokens_to_token.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.tokens_to_token(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.encoder): + x = layer(x) + + if i == len(self.encoder) - 1 and self.final_norm: + x = self.norm(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) diff --git a/mmpretrain/models/backbones/timm_backbone.py b/mmpretrain/models/backbones/timm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..51ecbdbb077be0643026de2ec91c0169263a41f7 --- /dev/null +++ b/mmpretrain/models/backbones/timm_backbone.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmengine.logging import MMLogger + +from mmpretrain.registry import MODELS +from mmpretrain.utils import require +from .base_backbone import BaseBackbone + + +def print_timm_feature_info(feature_info): + """Print feature_info of timm backbone to help development and debug. + + Args: + feature_info (list[dict] | timm.models.features.FeatureInfo | None): + feature_info of timm backbone. + """ + logger = MMLogger.get_current_instance() + if feature_info is None: + logger.warning('This backbone does not have feature_info') + elif isinstance(feature_info, list): + for feat_idx, each_info in enumerate(feature_info): + logger.info(f'backbone feature_info[{feat_idx}]: {each_info}') + else: + try: + logger.info(f'backbone out_indices: {feature_info.out_indices}') + logger.info(f'backbone out_channels: {feature_info.channels()}') + logger.info(f'backbone out_strides: {feature_info.reduction()}') + except AttributeError: + logger.warning('Unexpected format of backbone feature_info') + + +@MODELS.register_module() +class TIMMBackbone(BaseBackbone): + """Wrapper to use backbones from timm library. + + More details can be found in + `timm `_. + See especially the document for `feature extraction + `_. + + Args: + model_name (str): Name of timm model to instantiate. + features_only (bool): Whether to extract feature pyramid (multi-scale + feature maps from the deepest layer at each stride). For Vision + Transformer models that do not support this argument, + set this False. Defaults to False. + pretrained (bool): Whether to load pretrained weights. + Defaults to False. + checkpoint_path (str): Path of checkpoint to load at the last of + ``timm.create_model``. Defaults to empty string, which means + not loading. + in_channels (int): Number of input image channels. Defaults to 3. + init_cfg (dict or list[dict], optional): Initialization config dict of + OpenMMLab projects. Defaults to None. + **kwargs: Other timm & model specific arguments. + """ + + @require('timm') + def __init__(self, + model_name, + features_only=False, + pretrained=False, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs): + import timm + + if not isinstance(pretrained, bool): + raise TypeError('pretrained must be bool, not str for model path') + if features_only and checkpoint_path: + warnings.warn( + 'Using both features_only and checkpoint_path will cause error' + ' in timm. See ' + 'https://github.com/rwightman/pytorch-image-models/issues/488') + + super(TIMMBackbone, self).__init__(init_cfg) + if 'norm_layer' in kwargs: + norm_class = MODELS.get(kwargs['norm_layer']) + + def build_norm(*args, **kwargs): + return norm_class(*args, **kwargs) + + kwargs['norm_layer'] = build_norm + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs) + + # reset classifier + if hasattr(self.timm_model, 'reset_classifier'): + self.timm_model.reset_classifier(0, '') + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + feature_info = getattr(self.timm_model, 'feature_info', None) + print_timm_feature_info(feature_info) + + def forward(self, x): + features = self.timm_model(x) + if isinstance(features, (list, tuple)): + features = tuple(features) + else: + features = (features, ) + return features diff --git a/mmpretrain/models/backbones/tinyvit.py b/mmpretrain/models/backbones/tinyvit.py new file mode 100644 index 0000000000000000000000000000000000000000..5279832184343a6e8ff4b253891de1b990192775 --- /dev/null +++ b/mmpretrain/models/backbones/tinyvit.py @@ -0,0 +1,769 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn import functional as F + +from mmpretrain.registry import MODELS +from ..utils import LeAttention +from .base_backbone import BaseBackbone + + +class ConvBN2d(Sequential): + """An implementation of Conv2d + BatchNorm2d with support of fusion. + + Modified from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int): The size of the convolution kernel. + Default: 1. + stride (int): The stride of the convolution. + Default: 1. + padding (int): The padding of the convolution. + Default: 0. + dilation (int): The dilation of the convolution. + Default: 1. + groups (int): The number of groups in the convolution. + Default: 1. + bn_weight_init (float): The initial value of the weight of + the nn.BatchNorm2d layer. Default: 1.0. + init_cfg (dict): The initialization config of the module. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bn_weight_init=1.0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.add_module( + 'conv2d', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False)) + bn2d = nn.BatchNorm2d(num_features=out_channels) + # bn initialization + torch.nn.init.constant_(bn2d.weight, bn_weight_init) + torch.nn.init.constant_(bn2d.bias, 0) + self.add_module('bn2d', bn2d) + + @torch.no_grad() + def fuse(self): + conv2d, bn2d = self._modules.values() + w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5 + w = conv2d.weight * w[:, None, None, None] + b = bn2d.bias - bn2d.running_mean * bn2d.weight / \ + (bn2d.running_var + bn2d.eps)**0.5 + + m = nn.Conv2d( + in_channels=w.size(1) * self.c.groups, + out_channels=w.size(0), + kernel_size=w.shape[2:], + stride=self.conv2d.stride, + padding=self.conv2d.padding, + dilation=self.conv2d.dilation, + groups=self.conv2d.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class PatchEmbed(BaseModule): + """Patch Embedding for Vision Transformer. + + Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module use + Conv2d and BatchNorm2d to implement PatchEmbedding, and output shape is + (N, C, H, W). + + Args: + in_channels (int): The number of input channels. + embed_dim (int): The embedding dimension. + resolution (Tuple[int, int]): The resolution of the input feature. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + embed_dim, + resolution, + act_cfg=dict(type='GELU')): + super().__init__() + img_size: Tuple[int, int] = resolution + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_channels = in_channels + self.embed_dim = embed_dim + self.seq = nn.Sequential( + ConvBN2d( + in_channels, + embed_dim // 2, + kernel_size=3, + stride=2, + padding=1), + build_activation_layer(act_cfg), + ConvBN2d( + embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), + ) + + def forward(self, x): + return self.seq(x) + + +class PatchMerging(nn.Module): + """Patch Merging for TinyViT. + + Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Different from `mmpretrain.models.utils.PatchMerging`, this module use + Conv2d and BatchNorm2d to implement PatchMerging. + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + out_channels (int): The number of output channels. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + resolution, + in_channels, + out_channels, + act_cfg=dict(type='GELU')): + super().__init__() + + self.img_size = resolution + + self.act = build_activation_layer(act_cfg) + self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1) + self.conv2 = ConvBN2d( + out_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + groups=out_channels) + self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1) + self.out_resolution = (resolution[0] // 2, resolution[1] // 2) + + def forward(self, x): + if len(x.shape) == 3: + H, W = self.img_size + B = x.shape[0] + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + x = self.conv1(x) + x = self.act(x) + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + + x = x.flatten(2).transpose(1, 2) + return x + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + expand_ratio (int): The expand ratio of the hidden channels. + drop_rate (float): The drop rate of the block. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + out_channels, + expand_ratio, + drop_path, + act_cfg=dict(type='GELU')): + super().__init__() + self.in_channels = in_channels + hidden_channels = int(in_channels * expand_ratio) + + # linear + self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1) + self.act = build_activation_layer(act_cfg) + # depthwise conv + self.conv2 = ConvBN2d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_channels) + # linear + self.conv3 = ConvBN2d( + hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act(x) + + return x + + +class ConvStage(BaseModule): + """Convolution Stage for TinyViT. + + Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + depth (int): The number of blocks in the stage. + act_cfg (dict): The activation config of the module. + drop_path (float): The drop path of the block. + downsample (None | nn.Module): The downsample operation. + Default: None. + use_checkpoint (bool): Whether to use checkpointing to save memory. + out_channels (int): The number of output channels. + conv_expand_ratio (int): The expand ratio of the hidden channels. + Default: 4. + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + resolution, + depth, + act_cfg, + drop_path=0., + downsample=None, + use_checkpoint=False, + out_channels=None, + conv_expand_ratio=4., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.use_checkpoint = use_checkpoint + # build blocks + self.blocks = ModuleList([ + MBConvBlock( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=conv_expand_ratio, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path) + for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + resolution=resolution, + in_channels=in_channels, + out_channels=out_channels, + act_cfg=act_cfg) + self.resolution = self.downsample.out_resolution + else: + self.downsample = None + self.resolution = resolution + + def forward(self, x): + for block in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(block, x) + else: + x = block(x) + + if self.downsample is not None: + x = self.downsample(x) + return x + + +class MLP(BaseModule): + """MLP module for TinyViT. + + Args: + in_channels (int): The number of input channels. + hidden_channels (int, optional): The number of hidden channels. + Default: None. + out_channels (int, optional): The number of output channels. + Default: None. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + drop (float): Probability of an element to be zeroed. + Default: 0. + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.norm = nn.LayerNorm(in_channels) + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.fc2 = nn.Linear(hidden_channels, out_channels) + self.act = build_activation_layer(act_cfg) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class TinyViTBlock(BaseModule): + """TinViT Block. + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + num_heads (int): The number of heads in the multi-head attention. + window_size (int): The size of the window. + Default: 7. + mlp_ratio (float): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop (float): Probability of an element to be zeroed. + Default: 0. + drop_path (float): The drop path of the block. + Default: 0. + local_conv_size (int): The size of the local convolution. + Default: 3. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + resolution, + num_heads, + window_size=7, + mlp_ratio=4., + drop=0., + drop_path=0., + local_conv_size=3, + act_cfg=dict(type='GELU')): + super().__init__() + self.in_channels = in_channels + self.img_size = resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + assert in_channels % num_heads == 0, \ + 'dim must be divisible by num_heads' + head_dim = in_channels // num_heads + + window_resolution = (window_size, window_size) + self.attn = LeAttention( + in_channels, + head_dim, + num_heads, + attn_ratio=1, + resolution=window_resolution) + + mlp_hidden_dim = int(in_channels * mlp_ratio) + self.mlp = MLP( + in_channels=in_channels, + hidden_channels=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + self.local_conv = ConvBN2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=local_conv_size, + stride=1, + padding=local_conv_size // 2, + groups=in_channels) + + def forward(self, x): + H, W = self.img_size + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - + H % self.window_size) % self.window_size + pad_r = (self.window_size - + W % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, + C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + +class BasicStage(BaseModule): + """Basic Stage for TinyViT. + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + depth (int): The number of blocks in the stage. + num_heads (int): The number of heads in the multi-head attention. + window_size (int): The size of the window. + mlp_ratio (float): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop (float): Probability of an element to be zeroed. + Default: 0. + drop_path (float): The drop path of the block. + Default: 0. + downsample (None | nn.Module): The downsample operation. + Default: None. + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + drop=0., + drop_path=0., + downsample=None, + use_checkpoint=False, + local_conv_size=3, + out_channels=None, + act_cfg=dict(type='GELU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.use_checkpoint = use_checkpoint + # build blocks + self.blocks = ModuleList([ + TinyViTBlock( + in_channels=in_channels, + resolution=resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + local_conv_size=local_conv_size, + act_cfg=act_cfg, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path) + for i in range(depth) + ]) + + # build patch merging layer + if downsample is not None: + self.downsample = downsample( + resolution=resolution, + in_channels=in_channels, + out_channels=out_channels, + act_cfg=act_cfg) + self.resolution = self.downsample.out_resolution + else: + self.downsample = None + self.resolution = resolution + + def forward(self, x): + for block in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(block, x) + else: + x = block(x) + + if self.downsample is not None: + x = self.downsample(x) + return x + + +@MODELS.register_module() +class TinyViT(BaseBackbone): + """TinyViT. + A PyTorch implementation of : `TinyViT: Fast Pretraining Distillation + for Small Vision Transformers`_ + + Inspiration from + https://github.com/microsoft/Cream/blob/main/TinyViT + + Args: + arch (str | dict): The architecture of TinyViT. + Default: '5m'. + img_size (tuple | int): The resolution of the input image. + Default: (224, 224) + window_size (list): The size of the window. + Default: [7, 7, 14, 7] + in_channels (int): The number of input channels. + Default: 3. + depths (list[int]): The depth of each stage. + Default: [2, 2, 6, 2]. + mlp_ratio (list[int]): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop_rate (float): Probability of an element to be zeroed. + Default: 0. + drop_path_rate (float): The drop path of the block. + Default: 0.1. + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False. + mbconv_expand_ratio (int): The expand ratio of the mbconv. + Default: 4.0 + local_conv_size (int): The size of the local conv. + Default: 3. + layer_lr_decay (float): The layer lr decay. + Default: 1.0 + out_indices (int | list[int]): Output from which stages. + Default: -1 + frozen_stages (int | list[int]): Stages to be frozen (all param fixed). + Default: -0 + gap_before_final_nrom (bool): Whether to add a gap before the final + norm. Default: True. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. + """ + arch_settings = { + '5m': { + 'channels': [64, 128, 160, 320], + 'num_heads': [2, 4, 5, 10], + 'depths': [2, 2, 6, 2], + }, + '11m': { + 'channels': [64, 128, 256, 448], + 'num_heads': [2, 4, 8, 14], + 'depths': [2, 2, 6, 2], + }, + '21m': { + 'channels': [96, 192, 384, 576], + 'num_heads': [3, 6, 12, 18], + 'depths': [2, 2, 6, 2], + }, + } + + def __init__(self, + arch='5m', + img_size=(224, 224), + window_size=[7, 7, 14, 7], + in_channels=3, + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavaiable arch, please choose from ' \ + f'({set(self.arch_settings)} or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'channels' in arch and 'num_heads' in arch and \ + 'depths' in arch, 'The arch dict must have' \ + f'"channels", "num_heads", "window_sizes" ' \ + f'keys, but got {arch.keys()}' + + self.channels = arch['channels'] + self.num_heads = arch['num_heads'] + self.widow_sizes = window_size + self.img_size = img_size + self.depths = arch['depths'] + + self.num_stages = len(self.channels) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + self.layer_lr_decay = layer_lr_decay + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dim=self.channels[0], + resolution=self.img_size, + act_cfg=dict(type='GELU')) + patches_resolution = self.patch_embed.patches_resolution + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + + # build stages + self.stages = ModuleList() + for i in range(self.num_stages): + depth = self.depths[i] + channel = self.channels[i] + curr_resolution = (patches_resolution[0] // (2**i), + patches_resolution[1] // (2**i)) + drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])] + downsample = PatchMerging if (i < self.num_stages - 1) else None + out_channels = self.channels[min(i + 1, self.num_stages - 1)] + if i >= 1: + stage = BasicStage( + in_channels=channel, + resolution=curr_resolution, + depth=depth, + num_heads=self.num_heads[i], + window_size=self.widow_sizes[i], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=drop_path, + downsample=downsample, + use_checkpoint=use_checkpoint, + local_conv_size=local_conv_size, + out_channels=out_channels, + act_cfg=act_cfg) + else: + stage = ConvStage( + in_channels=channel, + resolution=curr_resolution, + depth=depth, + act_cfg=act_cfg, + drop_path=drop_path, + downsample=downsample, + use_checkpoint=use_checkpoint, + out_channels=out_channels, + conv_expand_ratio=mbconv_expand_ratio) + self.stages.append(stage) + + # add output norm + if i in self.out_indices: + norm_layer = build_norm_layer(norm_cfg, out_channels)[1] + self.add_module(f'norm{i}', norm_layer) + + def set_layer_lr_decay(self, layer_lr_decay): + # TODO: add layer_lr_decay + pass + + def forward(self, x): + outs = [] + x = self.patch_embed(x) + + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean(1) + outs.append(norm_layer(gap)) + else: + out = norm_layer(x) + # convert the (B,L,C) format into (B,C,H,W) format + # which would be better for the downstream tasks. + B, L, C = out.shape + out = out.view(B, *stage.resolution, C) + outs.append(out.permute(0, 3, 1, 2)) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(TinyViT, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/tnt.py b/mmpretrain/models/backbones/tnt.py new file mode 100644 index 0000000000000000000000000000000000000000..e1b241c1f6bc398157793748b7a457f0836daedb --- /dev/null +++ b/mmpretrain/models/backbones/tnt.py @@ -0,0 +1,368 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import to_2tuple +from .base_backbone import BaseBackbone + + +class TransformerBlock(BaseModule): + """Implement a transformer block in TnTLayer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer. + Default: 4 + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0. + drop_path_rate (float): stochastic depth rate. Default 0. + num_fcs (int): The number of fully-connected layers for FFNs. Default 2 + qkv_bias (bool): Enable bias for qkv if True. Default False + act_cfg (dict): The activation config for FFNs. Defaults to GELU. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) or (n, batch, embed_dim). + (batch, n, embed_dim) is common case in CV. Defaults to False + init_cfg (dict, optional): Initialization config dict. Defaults to None + """ + + def __init__(self, + embed_dims, + num_heads, + ffn_ratio=4, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + init_cfg=None): + super(TransformerBlock, self).__init__(init_cfg=init_cfg) + + self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + batch_first=batch_first) + + self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=embed_dims * ffn_ratio, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + if not qkv_bias: + self.attn.attn.in_proj_bias = None + + def forward(self, x): + x = self.attn(self.norm_attn(x), identity=x) + x = self.ffn(self.norm_ffn(x), identity=x) + return x + + +class TnTLayer(BaseModule): + """Implement one encoder layer in Transformer in Transformer. + + Args: + num_pixel (int): The pixel number in target patch transformed with + a linear projection in inner transformer + embed_dims_inner (int): Feature dimension in inner transformer block + embed_dims_outer (int): Feature dimension in outer transformer block + num_heads_inner (int): Parallel attention heads in inner transformer. + num_heads_outer (int): Parallel attention heads in outer transformer. + inner_block_cfg (dict): Extra config of inner transformer block. + Defaults to empty dict. + outer_block_cfg (dict): Extra config of outer transformer block. + Defaults to empty dict. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + init_cfg (dict, optional): Initialization config dict. Defaults to None + """ + + def __init__(self, + num_pixel, + embed_dims_inner, + embed_dims_outer, + num_heads_inner, + num_heads_outer, + inner_block_cfg=dict(), + outer_block_cfg=dict(), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(TnTLayer, self).__init__(init_cfg=init_cfg) + + self.inner_block = TransformerBlock( + embed_dims=embed_dims_inner, + num_heads=num_heads_inner, + **inner_block_cfg) + + self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1] + self.projection = nn.Linear( + embed_dims_inner * num_pixel, embed_dims_outer, bias=True) + + self.outer_block = TransformerBlock( + embed_dims=embed_dims_outer, + num_heads=num_heads_outer, + **outer_block_cfg) + + def forward(self, pixel_embed, patch_embed): + pixel_embed = self.inner_block(pixel_embed) + + B, N, C = patch_embed.size() + patch_embed[:, 1:] = patch_embed[:, 1:] + self.projection( + self.norm_proj(pixel_embed).reshape(B, N - 1, -1)) + patch_embed = self.outer_block(patch_embed) + + return pixel_embed, patch_embed + + +class PixelEmbed(BaseModule): + """Image to Pixel Embedding. + + Args: + img_size (int | tuple): The size of input image + patch_size (int): The size of one patch + in_channels (int): The num of input channels + embed_dims_inner (int): The num of channels of the target patch + transformed with a linear projection in inner transformer + stride (int): The stride of the conv2d layer. We use a conv2d layer + and a unfold layer to implement image to pixel embedding. + init_cfg (dict, optional): Initialization config dict + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims_inner=48, + stride=4, + init_cfg=None): + super(PixelEmbed, self).__init__(init_cfg=init_cfg) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + # patches_resolution property necessary for resizing + # positional embedding + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + num_patches = patches_resolution[0] * patches_resolution[1] + + self.img_size = img_size + self.num_patches = num_patches + self.embed_dims_inner = embed_dims_inner + + new_patch_size = [math.ceil(ps / stride) for ps in patch_size] + self.new_patch_size = new_patch_size + + self.proj = nn.Conv2d( + in_channels, + self.embed_dims_inner, + kernel_size=7, + padding=3, + stride=stride) + self.unfold = nn.Unfold( + kernel_size=new_patch_size, stride=new_patch_size) + + def forward(self, x, pixel_pos): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model " \ + f'({self.img_size[0]}*{self.img_size[1]}).' + x = self.proj(x) + x = self.unfold(x) + x = x.transpose(1, + 2).reshape(B * self.num_patches, self.embed_dims_inner, + self.new_patch_size[0], + self.new_patch_size[1]) + x = x + pixel_pos + x = x.reshape(B * self.num_patches, self.embed_dims_inner, + -1).transpose(1, 2) + return x + + +@MODELS.register_module() +class TNT(BaseBackbone): + """Transformer in Transformer. + + A PyTorch implement of: `Transformer in Transformer + `_ + + Inspiration from + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size. Defaults to 224 + patch_size (int | tuple): The patch size. Deault to 16 + in_channels (int): Number of input channels. Defaults to 3 + ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer. + Default: 4 + qkv_bias (bool): Enable bias for qkv if True. Default False + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0. + drop_path_rate (float): stochastic depth rate. Default 0. + act_cfg (dict): The activation config for FFNs. Defaults to GELU. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + first_stride (int): The stride of the conv2d layer. We use a conv2d + layer and a unfold layer to implement image to pixel embedding. + num_fcs (int): The number of fully-connected layers for FFNs. Default 2 + init_cfg (dict, optional): Initialization config dict + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims_outer': 384, + 'embed_dims_inner': 24, + 'num_layers': 12, + 'num_heads_outer': 6, + 'num_heads_inner': 4 + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims_outer': 640, + 'embed_dims_inner': 40, + 'num_layers': 12, + 'num_heads_outer': 10, + 'num_heads_inner': 4 + }) + } + + def __init__(self, + arch='b', + img_size=224, + patch_size=16, + in_channels=3, + ffn_ratio=4, + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + first_stride=4, + num_fcs=2, + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) + ]): + super(TNT, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims_outer', 'embed_dims_inner', 'num_layers', + 'num_heads_inner', 'num_heads_outer' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims_inner = self.arch_settings['embed_dims_inner'] + self.embed_dims_outer = self.arch_settings['embed_dims_outer'] + # embed_dims for consistency with other models + self.embed_dims = self.embed_dims_outer + self.num_layers = self.arch_settings['num_layers'] + self.num_heads_inner = self.arch_settings['num_heads_inner'] + self.num_heads_outer = self.arch_settings['num_heads_outer'] + + self.pixel_embed = PixelEmbed( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims_inner=self.embed_dims_inner, + stride=first_stride) + num_patches = self.pixel_embed.num_patches + self.num_patches = num_patches + new_patch_size = self.pixel_embed.new_patch_size + num_pixel = new_patch_size[0] * new_patch_size[1] + + self.norm1_proj = build_norm_layer(norm_cfg, num_pixel * + self.embed_dims_inner)[1] + self.projection = nn.Linear(num_pixel * self.embed_dims_inner, + self.embed_dims_outer) + self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1] + + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims_outer)) + self.patch_pos = nn.Parameter( + torch.zeros(1, num_patches + 1, self.embed_dims_outer)) + self.pixel_pos = nn.Parameter( + torch.zeros(1, self.embed_dims_inner, new_patch_size[0], + new_patch_size[1])) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, self.num_layers) + ] # stochastic depth decay rule + self.layers = ModuleList() + for i in range(self.num_layers): + block_cfg = dict( + ffn_ratio=ffn_ratio, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + batch_first=True) + self.layers.append( + TnTLayer( + num_pixel=num_pixel, + embed_dims_inner=self.embed_dims_inner, + embed_dims_outer=self.embed_dims_outer, + num_heads_inner=self.num_heads_inner, + num_heads_outer=self.num_heads_outer, + inner_block_cfg=block_cfg, + outer_block_cfg=block_cfg, + norm_cfg=norm_cfg)) + + self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1] + + trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.patch_pos, std=.02) + trunc_normal_(self.pixel_pos, std=.02) + + def forward(self, x): + B = x.shape[0] + pixel_embed = self.pixel_embed(x, self.pixel_pos) + + patch_embed = self.norm2_proj( + self.projection( + self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) + patch_embed = torch.cat( + (self.cls_token.expand(B, -1, -1), patch_embed), dim=1) + patch_embed = patch_embed + self.patch_pos + patch_embed = self.drop_after_pos(patch_embed) + + for layer in self.layers: + pixel_embed, patch_embed = layer(pixel_embed, patch_embed) + + patch_embed = self.norm(patch_embed) + return (patch_embed[:, 0], ) diff --git a/mmpretrain/models/backbones/twins.py b/mmpretrain/models/backbones/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..be55c02db1daa5cb37760f2066448b3fca2cb893 --- /dev/null +++ b/mmpretrain/models/backbones/twins.py @@ -0,0 +1,721 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import ConditionalPositionEncoding, MultiheadAttention + + +class GlobalSubsampledAttention(MultiheadAttention): + """Global Sub-sampled Attention (GSA) module. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + sr_ratio (float): The ratio of spatial reduction in attention modules. + Defaults to 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + norm_cfg=dict(type='LN'), + qkv_bias=True, + sr_ratio=1, + **kwargs): + super(GlobalSubsampledAttention, + self).__init__(embed_dims, num_heads, **kwargs) + + self.qkv_bias = qkv_bias + self.q = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias) + self.kv = nn.Linear(self.input_dims, embed_dims * 2, bias=qkv_bias) + + # remove self.qkv, here split into self.q, self.kv + delattr(self, 'qkv') + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + # use a conv as the spatial-reduction operation, the kernel_size + # and stride in conv are equal to the sr_ratio. + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, x, hw_shape): + B, N, C = x.shape + H, W = hw_shape + assert H * W == N, 'The product of h and w of hw_shape must be N, ' \ + 'which is the 2nd dim number of the input Tensor x.' + + q = self.q(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x = x.permute(0, 2, 1).reshape(B, C, *hw_shape) # BNC_2_BCHW + x = self.sr(x) + x = x.reshape(B, C, -1).permute(0, 2, 1) # BCHW_2_BNC + x = self.norm(x) + + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn_drop = self.attn_drop if self.training else 0. + x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + + x = self.proj(x) + x = self.out_drop(self.proj_drop(x)) + + if self.v_shortcut: + x = v.squeeze(1) + x + return x + + +class GSAEncoderLayer(BaseModule): + """Implements one encoder layer with GlobalSubsampledAttention(GSA). + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (float): The ratio of spatial reduction in attention modules. + Defaults to 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1., + init_cfg=None): + super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = GlobalSubsampledAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class LocallyGroupedSelfAttention(BaseModule): + """Locally-grouped Self Attention (LSA) module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + window_size(int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + window_size=1, + init_cfg=None): + super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg) + + assert embed_dims % num_heads == 0, \ + f'dim {embed_dims} should be divided by num_heads {num_heads}' + + self.embed_dims = embed_dims + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + self.window_size = window_size + + def forward(self, x, hw_shape): + B, N, C = x.shape + H, W = hw_shape + x = x.view(B, H, W, C) + + # pad feature maps to multiples of Local-groups + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + # calculate attention mask for LSA + Hp, Wp = x.shape[1:-1] + _h, _w = Hp // self.window_size, Wp // self.window_size + mask = torch.zeros((1, Hp, Wp), device=x.device) + mask[:, -pad_b:, :].fill_(1) + mask[:, :, -pad_r:].fill_(1) + + # [B, _h, _w, window_size, window_size, C] + x = x.reshape(B, _h, self.window_size, _w, self.window_size, + C).transpose(2, 3) + mask = mask.reshape(1, _h, self.window_size, _w, + self.window_size).transpose(2, 3).reshape( + 1, _h * _w, + self.window_size * self.window_size) + # [1, _h*_w, window_size*window_size, window_size*window_size] + attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-1000.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # [3, B, _w*_h, nhead, window_size*window_size, dim] + qkv = self.qkv(x).reshape(B, _h * _w, + self.window_size * self.window_size, 3, + self.num_heads, C // self.num_heads).permute( + 3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + # [B, _h*_w, n_head, window_size*window_size, window_size*window_size] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + attn_mask.unsqueeze(2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.window_size, + self.window_size, C) + x = attn.transpose(2, 3).reshape(B, _h * self.window_size, + _w * self.window_size, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LSAEncoderLayer(BaseModule): + """Implements one encoder layer with LocallyGroupedSelfAttention(LSA). + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + window_size (int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=1, + init_cfg=None): + + super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, + qkv_bias, qk_scale, + attn_drop_rate, drop_rate, + window_size) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +@MODELS.register_module() +class PCPVT(BaseModule): + """The backbone of Twins-PCPVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + arch (dict, str): PCPVT architecture, a str value in arch zoo or a + detailed configuration dict with 7 keys, and the length of all the + values in dict should be the same: + + - depths (List[int]): The number of encoder layers in each stage. + - embed_dims (List[int]): Embedding dimension in each stage. + - patch_sizes (List[int]): The patch sizes in each stage. + - num_heads (List[int]): Numbers of attention head in each stage. + - strides (List[int]): The strides in each stage. + - mlp_ratios (List[int]): The ratios of mlp in each stage. + - sr_ratios (List[int]): The ratios of GSA-encoder layers in each + stage. + + in_channels (int): Number of input channels. Defaults to 3. + out_indices (tuple[int]): Output from which stages. + Defaults to ``(3, )``. + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0 + drop_path_rate (float): Stochastic depth rate. Defaults to 0.0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + norm_after_stage(bool, List[bool]): Add extra norm after each stage. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import PCPVT + >>> import torch + >>> pcpvt_cfg = {'arch': "small", + >>> 'norm_after_stage': [False, False, False, True]} + >>> model = PCPVT(**pcpvt_cfg) + >>> x = torch.rand(1, 3, 224, 224) + >>> outputs = model(x) + >>> print(outputs[-1].shape) + torch.Size([1, 512, 7, 7]) + >>> pcpvt_cfg['norm_after_stage'] = [True, True, True, True] + >>> pcpvt_cfg['out_indices'] = (0, 1, 2, 3) + >>> model = PCPVT(**pcpvt_cfg) + >>> outputs = model(x) + >>> for feat in outputs: + >>> print(feat.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 320, 14, 14]) + torch.Size([1, 512, 7, 7]) + """ + arch_zoo = { + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 4, 6, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 4, 18, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 8, 27, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + } # yapf: disable + + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides', + 'mlp_ratios', 'sr_ratios' + } + + def __init__(self, + arch, + in_channels=3, + out_indices=(3, ), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + norm_after_stage=False, + init_cfg=None): + super(PCPVT, self).__init__(init_cfg=init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + assert isinstance(arch, dict) and ( + set(arch) == self.essential_keys + ), f'Custom arch needs a dict with keys {self.essential_keys}.' + self.arch_settings = arch + + self.depths = self.arch_settings['depths'] + self.embed_dims = self.arch_settings['embed_dims'] + self.patch_sizes = self.arch_settings['patch_sizes'] + self.strides = self.arch_settings['strides'] + self.mlp_ratios = self.arch_settings['mlp_ratios'] + self.num_heads = self.arch_settings['num_heads'] + self.sr_ratios = self.arch_settings['sr_ratios'] + + self.num_extra_tokens = 0 # there is no cls-token in Twins + self.num_stage = len(self.depths) + for key, value in self.arch_settings.items(): + assert isinstance(value, list) and len(value) == self.num_stage, ( + 'Length of setting item in arch dict must be type of list and' + ' have the same length.') + + # patch_embeds + self.patch_embeds = ModuleList() + self.position_encoding_drops = ModuleList() + self.stages = ModuleList() + + for i in range(self.num_stage): + # use in_channels of the model in the first stage + if i == 0: + stage_in_channels = in_channels + else: + stage_in_channels = self.embed_dims[i - 1] + + self.patch_embeds.append( + PatchEmbed( + in_channels=stage_in_channels, + embed_dims=self.embed_dims[i], + conv_type='Conv2d', + kernel_size=self.patch_sizes[i], + stride=self.strides[i], + padding='corner', + norm_cfg=dict(type='LN'))) + + self.position_encoding_drops.append(nn.Dropout(p=drop_rate)) + + # PEGs + self.position_encodings = ModuleList([ + ConditionalPositionEncoding(embed_dim, embed_dim) + for embed_dim in self.embed_dims + ]) + + # stochastic depth + total_depth = sum(self.depths) + self.dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + cur = 0 + + for k in range(len(self.depths)): + _block = ModuleList([ + GSAEncoderLayer( + embed_dims=self.embed_dims[k], + num_heads=self.num_heads[k], + feedforward_channels=self.mlp_ratios[k] * + self.embed_dims[k], + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=self.dpr[cur + i], + num_fcs=2, + qkv_bias=qkv_bias, + act_cfg=dict(type='GELU'), + norm_cfg=norm_cfg, + sr_ratio=self.sr_ratios[k]) for i in range(self.depths[k]) + ]) + self.stages.append(_block) + cur += self.depths[k] + + self.out_indices = out_indices + + assert isinstance(norm_after_stage, (bool, list)) + if isinstance(norm_after_stage, bool): + self.norm_after_stage = [norm_after_stage] * self.num_stage + else: + self.norm_after_stage = norm_after_stage + assert len(self.norm_after_stage) == self.num_stage, \ + (f'Number of norm_after_stage({len(self.norm_after_stage)}) should' + f' be equal to the number of stages({self.num_stage}).') + + for i, has_norm in enumerate(self.norm_after_stage): + assert isinstance(has_norm, bool), 'norm_after_stage should be ' \ + 'bool or List[bool].' + if has_norm and norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, self.embed_dims[i])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm_after_stage{i}', norm_layer) + + def init_weights(self): + if self.init_cfg is not None: + super(PCPVT, self).init_weights() + else: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + + def forward(self, x): + outputs = list() + + b = x.shape[0] + + for i in range(self.num_stage): + x, hw_shape = self.patch_embeds[i](x) + h, w = hw_shape + x = self.position_encoding_drops[i](x) + for j, blk in enumerate(self.stages[i]): + x = blk(x, hw_shape) + if j == 0: + x = self.position_encodings[i](x, hw_shape) + + norm_layer = getattr(self, f'norm_after_stage{i}') + x = norm_layer(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + + if i in self.out_indices: + outputs.append(x) + + return tuple(outputs) + + +@MODELS.register_module() +class SVT(PCPVT): + """The backbone of Twins-SVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + arch (dict, str): SVT architecture, a str value in arch zoo or a + detailed configuration dict with 8 keys, and the length of all the + values in dict should be the same: + + - depths (List[int]): The number of encoder layers in each stage. + - embed_dims (List[int]): Embedding dimension in each stage. + - patch_sizes (List[int]): The patch sizes in each stage. + - num_heads (List[int]): Numbers of attention head in each stage. + - strides (List[int]): The strides in each stage. + - mlp_ratios (List[int]): The ratios of mlp in each stage. + - sr_ratios (List[int]): The ratios of GSA-encoder layers in each + stage. + - windiow_sizes (List[int]): The window sizes in LSA-encoder layers + in each stage. + + in_channels (int): Number of input channels. Defaults to 3. + out_indices (tuple[int]): Output from which stages. + Defaults to (3, ). + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + drop_rate (float): Dropout rate. Defaults to 0. + attn_drop_rate (float): Dropout ratio of attention weight. + Defaults to 0.0 + drop_path_rate (float): Stochastic depth rate. Defaults to 0.2. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + norm_after_stage(bool, List[bool]): Add extra norm after each stage. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import SVT + >>> import torch + >>> svt_cfg = {'arch': "small", + >>> 'norm_after_stage': [False, False, False, True]} + >>> model = SVT(**svt_cfg) + >>> x = torch.rand(1, 3, 224, 224) + >>> outputs = model(x) + >>> print(outputs[-1].shape) + torch.Size([1, 512, 7, 7]) + >>> svt_cfg["out_indices"] = (0, 1, 2, 3) + >>> svt_cfg["norm_after_stage"] = [True, True, True, True] + >>> model = SVT(**svt_cfg) + >>> output = model(x) + >>> for feat in output: + >>> print(feat.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 320, 14, 14]) + torch.Size([1, 512, 7, 7]) + """ + arch_zoo = { + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 256, 512], + 'depths': [2, 2, 10, 4], + 'num_heads': [2, 4, 8, 16], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [96, 192, 384, 768], + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [128, 256, 512, 1024], + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + } # yapf: disable + + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides', + 'mlp_ratios', 'sr_ratios', 'window_sizes' + } + + def __init__(self, + arch, + in_channels=3, + out_indices=(3, ), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.0, + norm_cfg=dict(type='LN'), + norm_after_stage=False, + init_cfg=None): + super(SVT, self).__init__(arch, in_channels, out_indices, qkv_bias, + drop_rate, attn_drop_rate, drop_path_rate, + norm_cfg, norm_after_stage, init_cfg) + + self.window_sizes = self.arch_settings['window_sizes'] + + for k in range(self.num_stage): + for i in range(self.depths[k]): + # in even-numbered layers of each stage, replace GSA with LSA + if i % 2 == 0: + ffn_channels = self.mlp_ratios[k] * self.embed_dims[k] + self.stages[k][i] = \ + LSAEncoderLayer( + embed_dims=self.embed_dims[k], + num_heads=self.num_heads[k], + feedforward_channels=ffn_channels, + drop_rate=drop_rate, + norm_cfg=norm_cfg, + attn_drop_rate=attn_drop_rate, + drop_path_rate=self.dpr[sum(self.depths[:k])+i], + qkv_bias=qkv_bias, + window_size=self.window_sizes[k]) diff --git a/mmpretrain/models/backbones/van.py b/mmpretrain/models/backbones/van.py new file mode 100644 index 0000000000000000000000000000000000000000..c34dc3362f84ffa39151219f038f0c74ee0242e8 --- /dev/null +++ b/mmpretrain/models/backbones/van.py @@ -0,0 +1,434 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class MixFFN(BaseModule): + """An implementation of MixFFN of VAN. Refer to + mmdetection/mmdet/models/backbones/pvt.py. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Depth-wise Conv to encode positional information. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. + feedforward_channels (int): The hidden dimension of FFNs. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='GELU'). + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + init_cfg=None): + super(MixFFN, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + + self.fc1 = Conv2d( + in_channels=embed_dims, + out_channels=feedforward_channels, + kernel_size=1) + self.dwconv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=feedforward_channels) + self.act = build_activation_layer(act_cfg) + self.fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=embed_dims, + kernel_size=1) + self.drop = nn.Dropout(ffn_drop) + + def forward(self, x): + x = self.fc1(x) + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class LKA(BaseModule): + """Large Kernel Attention(LKA) of VAN. + + .. code:: text + DW_conv (depth-wise convolution) + | + | + DW_D_conv (depth-wise dilation convolution) + | + | + Transition Convolution (1×1 convolution) + + Args: + embed_dims (int): Number of input channels. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, embed_dims, init_cfg=None): + super(LKA, self).__init__(init_cfg=init_cfg) + + # a spatial local convolution (depth-wise convolution) + self.DW_conv = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=5, + padding=2, + groups=embed_dims) + + # a spatial long-range convolution (depth-wise dilation convolution) + self.DW_D_conv = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=7, + stride=1, + padding=9, + groups=embed_dims, + dilation=3) + + self.conv1 = Conv2d( + in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) + + def forward(self, x): + u = x.clone() + attn = self.DW_conv(x) + attn = self.DW_D_conv(attn) + attn = self.conv1(attn) + + return u * attn + + +class SpatialAttention(BaseModule): + """Basic attention module in VANBloack. + + Args: + embed_dims (int): Number of input channels. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='GELU'). + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None): + super(SpatialAttention, self).__init__(init_cfg=init_cfg) + + self.proj_1 = Conv2d( + in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) + self.activation = build_activation_layer(act_cfg) + self.spatial_gating_unit = LKA(embed_dims) + self.proj_2 = Conv2d( + in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) + + def forward(self, x): + shorcut = x.clone() + x = self.proj_1(x) + x = self.activation(x) + x = self.spatial_gating_unit(x) + x = self.proj_2(x) + x = x + shorcut + return x + + +class VANBlock(BaseModule): + """A block of VAN. + + Args: + embed_dims (int): Number of input channels. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='GELU'). + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-2. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + ffn_ratio=4., + drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN', eps=1e-5), + layer_scale_init_value=1e-2, + init_cfg=None): + super(VANBlock, self).__init__(init_cfg=init_cfg) + self.out_channels = embed_dims + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + mlp_hidden_dim = int(embed_dims * ffn_ratio) + self.mlp = MixFFN( + embed_dims=embed_dims, + feedforward_channels=mlp_hidden_dim, + act_cfg=act_cfg, + ffn_drop=drop_rate) + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) if layer_scale_init_value > 0 else None + + def forward(self, x): + identity = x + x = self.norm1(x) + x = self.attn(x) + if self.layer_scale_1 is not None: + x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x + x = identity + self.drop_path(x) + + identity = x + x = self.norm2(x) + x = self.mlp(x) + if self.layer_scale_2 is not None: + x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x + x = identity + self.drop_path(x) + + return x + + +class VANPatchEmbed(PatchEmbed): + """Image to Patch Embedding of VAN. + + The differences between VANPatchEmbed & PatchEmbed: + 1. Use BN. + 2. Do not use 'flatten' and 'transpose'. + """ + + def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs): + super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs) + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + Returns: + tuple: Contains merged results and its spatial shape. + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adaptive_padding: + x = self.adaptive_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +@MODELS.register_module() +class VAN(BaseBackbone): + """Visual Attention Network. + + A PyTorch implement of : `Visual Attention Network + `_ + + Inspiration from + https://github.com/Visual-Attention-Network/VAN-Classification + + Args: + arch (str | dict): Visual Attention Network architecture. + If use string, choose from 'tiny', 'small', 'base' and 'large'. + If use dict, it should have below keys: + + - **embed_dims** (List[int]): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **ffn_ratios** (List[int]): The number of expansion ratio of + feedforward network hidden layer channels. + + Defaults to 'tiny'. + patch_sizes (List[int | tuple]): The patch size in patch embeddings. + Defaults to [7, 3, 3, 3]. + in_channels (int): The num of input channels. Defaults to 3. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import VAN + >>> import torch + >>> cfg = dict(arch='tiny') + >>> model = VAN(**cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outputs = model(inputs) + >>> for out in outputs: + >>> print(out.size()) + (1, 256, 7, 7) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': [32, 64, 160, 256], + 'depths': [3, 3, 5, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [2, 2, 4, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 3, 12, 3], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 5, 27, 3], + 'ffn_ratios': [8, 8, 4, 4]}), + } # yapf: disable + + def __init__(self, + arch='tiny', + patch_sizes=[7, 3, 3, 3], + in_channels=3, + drop_rate=0., + drop_path_rate=0., + out_indices=(3, ), + frozen_stages=-1, + norm_eval=False, + norm_cfg=dict(type='LN'), + block_cfgs=dict(), + init_cfg=None): + super(VAN, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'ffn_ratios'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.ffn_ratios = self.arch_settings['ffn_ratios'] + self.num_stages = len(self.depths) + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + cur_block_idx = 0 + for i, depth in enumerate(self.depths): + patch_embed = VANPatchEmbed( + in_channels=in_channels if i == 0 else self.embed_dims[i - 1], + input_size=None, + embed_dims=self.embed_dims[i], + kernel_size=patch_sizes[i], + stride=patch_sizes[i] // 2 + 1, + padding=(patch_sizes[i] // 2, patch_sizes[i] // 2), + norm_cfg=dict(type='BN')) + + blocks = ModuleList([ + VANBlock( + embed_dims=self.embed_dims[i], + ffn_ratio=self.ffn_ratios[i], + drop_rate=drop_rate, + drop_path_rate=dpr[cur_block_idx + j], + **block_cfgs) for j in range(depth) + ]) + cur_block_idx += depth + norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1] + + self.add_module(f'patch_embed{i + 1}', patch_embed) + self.add_module(f'blocks{i + 1}', blocks) + self.add_module(f'norm{i + 1}', norm) + + def train(self, mode=True): + super(VAN, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + # freeze patch embed + m = getattr(self, f'patch_embed{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze blocks + m = getattr(self, f'blocks{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze norm + m = getattr(self, f'norm{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + blocks = getattr(self, f'blocks{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, hw_shape = patch_embed(x) + for block in blocks: + x = block(x) + x = x.flatten(2).transpose(1, 2) + x = norm(x) + x = x.reshape(-1, *hw_shape, + block.out_channels).permute(0, 3, 1, 2).contiguous() + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/vgg.py b/mmpretrain/models/backbones/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..026b916256cf56cdf75d348ee07b0ceceffd9751 --- /dev/null +++ b/mmpretrain/models/backbones/vgg.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +def make_vgg_layer(in_channels, + out_channels, + num_blocks, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + dilation=1, + with_norm=False, + ceil_mode=False): + layers = [] + for _ in range(num_blocks): + layer = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=dilation, + padding=dilation, + bias=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + layers.append(layer) + in_channels = out_channels + layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) + + return layers + + +@MODELS.register_module() +class VGG(BaseBackbone): + """VGG backbone. + + Args: + depth (int): Depth of vgg, from {11, 13, 16, 19}. + with_norm (bool): Use BatchNorm or not. + num_classes (int): number of classes for classification. + num_stages (int): VGG stages, normally 5. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int], optional): Output from which stages. + When it is None, the default behavior depends on whether + num_classes is specified. If num_classes <= 0, the default value is + (4, ), output the last feature map before classifier. If + num_classes > 0, the default value is (5, ), output the + classification score. Default: None. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False. + with_last_pool (bool): Whether to keep the last pooling before + classifier. Default: True. + """ + + # Parameters to build layers. Each element specifies the number of conv in + # each stage. For example, VGG11 contains 11 layers with learnable + # parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3, + # where 3 indicates the last three fully-connected layers. + arch_settings = { + 11: (1, 1, 2, 2, 2), + 13: (2, 2, 2, 2, 2), + 16: (2, 2, 3, 3, 3), + 19: (2, 2, 4, 4, 4) + } + + def __init__(self, + depth, + num_classes=-1, + num_stages=5, + dilations=(1, 1, 1, 1, 1), + out_indices=None, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + norm_eval=False, + ceil_mode=False, + with_last_pool=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict(type='Constant', val=1., layer=['_BatchNorm']), + dict(type='Normal', std=0.01, layer=['Linear']) + ]): + super(VGG, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for vgg') + assert num_stages >= 1 and num_stages <= 5 + stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + assert len(dilations) == num_stages + + self.num_classes = num_classes + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + with_norm = norm_cfg is not None + + if out_indices is None: + out_indices = (5, ) if num_classes > 0 else (4, ) + assert max(out_indices) <= num_stages + self.out_indices = out_indices + + self.in_channels = 3 + start_idx = 0 + vgg_layers = [] + self.range_sub_modules = [] + for i, num_blocks in enumerate(self.stage_blocks): + num_modules = num_blocks + 1 + end_idx = start_idx + num_modules + dilation = dilations[i] + out_channels = 64 * 2**i if i < 4 else 512 + vgg_layer = make_vgg_layer( + self.in_channels, + out_channels, + num_blocks, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dilation=dilation, + with_norm=with_norm, + ceil_mode=ceil_mode) + vgg_layers.extend(vgg_layer) + self.in_channels = out_channels + self.range_sub_modules.append([start_idx, end_idx]) + start_idx = end_idx + if not with_last_pool: + vgg_layers.pop(-1) + self.range_sub_modules[-1][1] -= 1 + self.module_name = 'features' + self.add_module(self.module_name, nn.Sequential(*vgg_layers)) + + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + outs = [] + vgg_layers = getattr(self, self.module_name) + for i in range(len(self.stage_blocks)): + for j in range(*self.range_sub_modules[i]): + vgg_layer = vgg_layers[j] + x = vgg_layer(x) + if i in self.out_indices: + outs.append(x) + if self.num_classes > 0: + x = x.view(x.size(0), -1) + x = self.classifier(x) + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + vgg_layers = getattr(self, self.module_name) + for i in range(self.frozen_stages): + for j in range(*self.range_sub_modules[i]): + m = vgg_layers[j] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(VGG, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/vig.py b/mmpretrain/models/backbones/vig.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a7879bd99682c32cbd1e02079fe79e2c6a3d0a --- /dev/null +++ b/mmpretrain/models/backbones/vig.py @@ -0,0 +1,852 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modified from +# https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +def get_2d_relative_pos_embed(embed_dim, grid_size): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, grid_size*grid_size] + """ + pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size) + relative_pos = 2 * np.matmul(pos_embed, + pos_embed.transpose()) / pos_embed.shape[1] + return relative_pos + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def xy_pairwise_distance(x, y): + """Compute pairwise distance of a point cloud. + + Args: + x: tensor (batch_size, num_points, num_dims) + y: tensor (batch_size, num_points, num_dims) + Returns: + pairwise distance: (batch_size, num_points, num_points) + """ + with torch.no_grad(): + xy_inner = -2 * torch.matmul(x, y.transpose(2, 1)) + x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) + y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True) + return x_square + xy_inner + y_square.transpose(2, 1) + + +def xy_dense_knn_matrix(x, y, k=16, relative_pos=None): + """Get KNN based on the pairwise distance. + + Args: + x: (batch_size, num_dims, num_points, 1) + y: (batch_size, num_dims, num_points, 1) + k: int + relative_pos:Whether to use relative_pos + Returns: + nearest neighbors: + (batch_size, num_points, k) (batch_size, num_points, k) + """ + with torch.no_grad(): + x = x.transpose(2, 1).squeeze(-1) + y = y.transpose(2, 1).squeeze(-1) + batch_size, n_points, n_dims = x.shape + dist = xy_pairwise_distance(x.detach(), y.detach()) + if relative_pos is not None: + dist += relative_pos + _, nn_idx = torch.topk(-dist, k=k) + center_idx = torch.arange( + 0, n_points, device=x.device).repeat(batch_size, k, + 1).transpose(2, 1) + return torch.stack((nn_idx, center_idx), dim=0) + + +class DenseDilated(nn.Module): + """Find dilated neighbor from neighbor list. + + edge_index: (2, batch_size, num_points, k) + """ + + def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0): + super(DenseDilated, self).__init__() + self.dilation = dilation + self.use_stochastic = use_stochastic + self.epsilon = epsilon + self.k = k + + def forward(self, edge_index): + if self.use_stochastic: + if torch.rand(1) < self.epsilon and self.training: + num = self.k * self.dilation + randnum = torch.randperm(num)[:self.k] + edge_index = edge_index[:, :, :, randnum] + else: + edge_index = edge_index[:, :, :, ::self.dilation] + else: + edge_index = edge_index[:, :, :, ::self.dilation] + return edge_index + + +class DenseDilatedKnnGraph(nn.Module): + """Find the neighbors' indices based on dilated knn.""" + + def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0): + super(DenseDilatedKnnGraph, self).__init__() + self.dilation = dilation + self.use_stochastic = use_stochastic + self.epsilon = epsilon + self.k = k + self._dilated = DenseDilated(k, dilation, use_stochastic, epsilon) + + def forward(self, x, y=None, relative_pos=None): + if y is not None: + x = F.normalize(x, p=2.0, dim=1) + y = F.normalize(y, p=2.0, dim=1) + + edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, + relative_pos) + else: + x = F.normalize(x, p=2.0, dim=1) + y = x.clone() + + edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, + relative_pos) + return self._dilated(edge_index) + + +class BasicConv(Sequential): + + def __init__(self, + channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True, + drop=0.): + m = [] + for i in range(1, len(channels)): + m.append( + nn.Conv2d( + channels[i - 1], + channels[i], + 1, + bias=graph_conv_bias, + groups=4)) + if norm_cfg is not None: + m.append(build_norm_layer(norm_cfg, channels[-1])) + if act_cfg is not None: + m.append(build_activation_layer(act_cfg)) + if drop > 0: + m.append(nn.Dropout2d(drop)) + + super(BasicConv, self).__init__(*m) + + +def batched_index_select(x, idx): + r"""fetches neighbors features from a given neighbor idx + + Args: + x (Tensor): input feature Tensor + :math: + `\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. + idx (Tensor): edge_idx + :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. + Returns: + Tensor: output neighbors features + :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. + """ + batch_size, num_dims, num_vertices_reduced = x.shape[:3] + _, num_vertices, k = idx.shape + idx_base = torch.arange( + 0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced + idx = idx + idx_base + idx = idx.contiguous().view(-1) + + x = x.transpose(2, 1) + feature = x.contiguous().view(batch_size * num_vertices_reduced, + -1)[idx, :] + feature = feature.view(batch_size, num_vertices, k, + num_dims).permute(0, 3, 1, 2).contiguous() + return feature + + +class MRConv2d(nn.Module): + """Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) + for dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(MRConv2d, self).__init__() + self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + + def forward(self, x, edge_index, y=None): + x_i = batched_index_select(x, edge_index[1]) + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j, _ = torch.max(x_j - x_i, -1, keepdim=True) + b, c, n, _ = x.shape + x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], + dim=2).reshape(b, 2 * c, n, _) + return self.nn(x) + + +class EdgeConv2d(nn.Module): + """Edge convolution layer (with activation, batch normalization) for dense + data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(EdgeConv2d, self).__init__() + self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + + def forward(self, x, edge_index, y=None): + x_i = batched_index_select(x, edge_index[1]) + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + max_value, _ = torch.max( + self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True) + return max_value + + +class GraphSAGE(nn.Module): + """GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216) + for dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(GraphSAGE, self).__init__() + self.nn1 = BasicConv([in_channels, in_channels], act_cfg, norm_cfg, + graph_conv_bias) + self.nn2 = BasicConv([in_channels * 2, out_channels], act_cfg, + norm_cfg, graph_conv_bias) + + def forward(self, x, edge_index, y=None): + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j, _ = torch.max(self.nn1(x_j), -1, keepdim=True) + return self.nn2(torch.cat([x, x_j], dim=1)) + + +class GINConv2d(nn.Module): + """GIN Graph Convolution (Paper: https://arxiv.org/abs/1810.00826) for + dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(GINConv2d, self).__init__() + self.nn = BasicConv([in_channels, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + eps_init = 0.0 + self.eps = nn.Parameter(torch.Tensor([eps_init])) + + def forward(self, x, edge_index, y=None): + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j = torch.sum(x_j, -1, keepdim=True) + return self.nn((1 + self.eps) * x + x_j) + + +class GraphConv2d(nn.Module): + """Static graph convolution layer.""" + + def __init__(self, + in_channels, + out_channels, + graph_conv_type, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(GraphConv2d, self).__init__() + if graph_conv_type == 'edge': + self.gconv = EdgeConv2d(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + elif graph_conv_type == 'mr': + self.gconv = MRConv2d(in_channels, out_channels, act_cfg, norm_cfg, + graph_conv_bias) + elif graph_conv_type == 'sage': + self.gconv = GraphSAGE(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + elif graph_conv_type == 'gin': + self.gconv = GINConv2d(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + else: + raise NotImplementedError( + 'graph_conv_type:{} is not supported'.format(graph_conv_type)) + + def forward(self, x, edge_index, y=None): + return self.gconv(x, edge_index, y) + + +class DyGraphConv2d(GraphConv2d): + """Dynamic graph convolution layer.""" + + def __init__(self, + in_channels, + out_channels, + k=9, + dilation=1, + graph_conv_type='mr', + act_cfg=dict(type='GELU'), + norm_cfg=None, + graph_conv_bias=True, + use_stochastic=False, + epsilon=0.2, + r=1): + super(DyGraphConv2d, + self).__init__(in_channels, out_channels, graph_conv_type, + act_cfg, norm_cfg, graph_conv_bias) + self.k = k + self.d = dilation + self.r = r + self.dilated_knn_graph = DenseDilatedKnnGraph(k, dilation, + use_stochastic, epsilon) + + def forward(self, x, relative_pos=None): + B, C, H, W = x.shape + y = None + if self.r > 1: + y = F.avg_pool2d(x, self.r, self.r) + y = y.reshape(B, C, -1, 1).contiguous() + x = x.reshape(B, C, -1, 1).contiguous() + edge_index = self.dilated_knn_graph(x, y, relative_pos) + x = super(DyGraphConv2d, self).forward(x, edge_index, y) + return x.reshape(B, -1, H, W).contiguous() + + +class Grapher(nn.Module): + """Grapher module with graph convolution and fc layers.""" + + def __init__(self, + in_channels, + k=9, + dilation=1, + graph_conv_type='mr', + act_cfg=dict(type='GELU'), + norm_cfg=None, + graph_conv_bias=True, + use_stochastic=False, + epsilon=0.2, + r=1, + n=196, + drop_path=0.0, + relative_pos=False): + super(Grapher, self).__init__() + self.channels = in_channels + self.n = n + self.r = r + self.fc1 = Sequential( + nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), in_channels), + ) + self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, k, + dilation, graph_conv_type, act_cfg, + norm_cfg, graph_conv_bias, + use_stochastic, epsilon, r) + self.fc2 = Sequential( + nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), in_channels), + ) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + self.relative_pos = None + if relative_pos: + relative_pos_tensor = torch.from_numpy( + np.float32( + get_2d_relative_pos_embed(in_channels, int( + n**0.5)))).unsqueeze(0).unsqueeze(1) + relative_pos_tensor = F.interpolate( + relative_pos_tensor, + size=(n, n // (r * r)), + mode='bicubic', + align_corners=False) + self.relative_pos = nn.Parameter( + -relative_pos_tensor.squeeze(1), requires_grad=False) + + def _get_relative_pos(self, relative_pos, H, W): + if relative_pos is None or H * W == self.n: + return relative_pos + else: + N = H * W + N_reduced = N // (self.r * self.r) + return F.interpolate( + relative_pos.unsqueeze(0), size=(N, N_reduced), + mode='bicubic').squeeze(0) + + def forward(self, x): + B, C, H, W = x.shape + relative_pos = self._get_relative_pos(self.relative_pos, H, W) + shortcut = x + x = self.fc1(x) + x = self.graph_conv(x, relative_pos) + x = self.fc2(x) + x = self.drop_path(x) + shortcut + return x + + +class FFN(nn.Module): + """"out_features = out_features or in_features\n + hidden_features = hidden_features or in_features""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop_path=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = Sequential( + nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), hidden_features), + ) + self.act = build_activation_layer(act_cfg) + self.fc2 = Sequential( + nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), out_features), + ) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop_path(x) + shortcut + return x + + +@MODELS.register_module() +class Vig(BaseBackbone): + """Vision GNN backbone. + + A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes + `_. + + Modified from the official implementation + https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch + + Args: + arch(str): Vision GNN architecture, + choose from 'tiny', 'small' and 'base'. + in_channels (int): The number of channels of input images. + Defaults to 3. + k (int): The number of KNN's k. Defaults to 9. + out_indices (Sequence | int): Output from which blocks. + Defaults to -1, means the last block. + act_cfg (dict): The config of activative functions. + Defaults to ``dict(type='GELU'))``. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='BN', eps=1e-6)``. + graph_conv_bias (bool): Whether to use bias in the convolution + layers in Grapher. Defaults to True. + graph_conv_type (str): The type of graph convolution,choose + from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'. + epsilon (float): Probability of random arrangement in KNN. It only + works when ``use_dilation=True`` and ``use_stochastic=True``. + Defaults to 0.2. + use_dilation(bool): Whether to use dilation in KNN. Defaults to True. + use_stochastic(bool): Whether to use stochastic in KNN. + Defaults to False. + drop_path (float): stochastic depth rate. Default 0.0 + relative_pos(bool): Whether to use relative position embedding. + Defaults to False. + norm_eval (bool): Whether to set the normalization layer to eval mode. + Defaults to False. + frozen_stages (int): Blocks to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): The initialization configs. + Defaults to None. + """ # noqa: E501 + + arch_settings = { + 'tiny': dict(num_blocks=12, channels=192), + 'small': dict(num_blocks=16, channels=320), + 'base': dict(num_blocks=16, channels=640), + } + + def __init__(self, + arch, + in_channels=3, + k=9, + out_indices=-1, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN'), + graph_conv_bias=True, + graph_conv_type='mr', + epsilon=0.2, + use_dilation=True, + use_stochastic=False, + drop_path=0., + relative_pos=False, + norm_eval=False, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + arch = self.arch_settings[arch] + self.num_blocks = arch['num_blocks'] + channels = arch['channels'] + + if isinstance(out_indices, int): + out_indices = [out_indices] + elif isinstance(out_indices, tuple): + out_indices = list(out_indices) + elif not isinstance(out_indices, list): + raise TypeError('"out_indices" must by a tuple, list or int, ' + f'get {type(out_indices)} instead.') + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_blocks + index + assert 0 <= out_indices[i] <= self.num_blocks, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.stem = Sequential( + nn.Conv2d(in_channels, channels // 8, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 8), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 8, channels // 4, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 4), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 4, channels // 2, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 2), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 2, channels, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels), + build_activation_layer(act_cfg), + nn.Conv2d(channels, channels, 3, stride=1, padding=1), + build_norm_layer(norm_cfg, channels), + ) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)] + # number of knn's k + num_knn = [ + int(x.item()) for x in torch.linspace(k, 2 * k, self.num_blocks) + ] + max_dilation = 196 // max(num_knn) + + self.pos_embed = nn.Parameter(torch.zeros(1, channels, 14, 14)) + + self.blocks = ModuleList([ + Sequential( + Grapher( + in_channels=channels, + k=num_knn[i], + dilation=min(i // 4 + + 1, max_dilation) if use_dilation else 1, + graph_conv_type=graph_conv_type, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + graph_conv_bias=graph_conv_bias, + use_stochastic=use_stochastic, + epsilon=epsilon, + drop_path=dpr[i], + relative_pos=relative_pos), + FFN(in_features=channels, + hidden_features=channels * 4, + act_cfg=act_cfg, + drop_path=dpr[i])) for i in range(self.num_blocks) + ]) + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + def forward(self, inputs): + outs = [] + x = self.stem(inputs) + self.pos_embed + + for i, block in enumerate(self.blocks): + x = block(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + self.stem.eval() + for i in range(self.frozen_stages): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(Vig, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() +class PyramidVig(BaseBackbone): + """Pyramid Vision GNN backbone. + + A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes + `_. + + Modified from the official implementation + https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch + + Args: + arch (str): Vision GNN architecture, choose from 'tiny', + 'small' and 'base'. + in_channels (int): The number of channels of input images. + Defaults to 3. + k (int): The number of KNN's k. Defaults to 9. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + act_cfg (dict): The config of activative functions. + Defaults to ``dict(type='GELU'))``. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='BN')``. + graph_conv_bias (bool): Whether to use bias in the convolution + layers in Grapher. Defaults to True. + graph_conv_type (str): The type of graph convolution,choose + from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'. + epsilon (float): Probability of random arrangement in KNN. It only + works when ``use_stochastic=True``. Defaults to 0.2. + use_stochastic (bool): Whether to use stochastic in KNN. + Defaults to False. + drop_path (float): stochastic depth rate. Default 0.0 + norm_eval (bool): Whether to set the normalization layer to eval mode. + Defaults to False. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): The initialization configs. + Defaults to None. + """ # noqa: E501 + arch_settings = { + 'tiny': dict(blocks=[2, 2, 6, 2], channels=[48, 96, 240, 384]), + 'small': dict(blocks=[2, 2, 6, 2], channels=[80, 160, 400, 640]), + 'medium': dict(blocks=[2, 2, 16, 2], channels=[96, 192, 384, 768]), + 'base': dict(blocks=[2, 2, 18, 2], channels=[128, 256, 512, 1024]), + } + + def __init__(self, + arch, + in_channels=3, + k=9, + out_indices=-1, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN'), + graph_conv_bias=True, + graph_conv_type='mr', + epsilon=0.2, + use_stochastic=False, + drop_path=0., + norm_eval=False, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + arch = self.arch_settings[arch] + self.blocks = arch['blocks'] + self.num_blocks = sum(self.blocks) + self.num_stages = len(self.blocks) + channels = arch['channels'] + self.channels = channels + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert 0 <= out_indices[i] <= self.num_stages, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.stem = Sequential( + nn.Conv2d(in_channels, channels[0] // 2, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels[0] // 2), + build_activation_layer(act_cfg), + nn.Conv2d(channels[0] // 2, channels[0], 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels[0]), + build_activation_layer(act_cfg), + nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1), + build_norm_layer(norm_cfg, channels[0]), + ) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)] + # number of knn's k + num_knn = [ + int(x.item()) for x in torch.linspace(k, k, self.num_blocks) + ] + max_dilation = 49 // max(num_knn) + + self.pos_embed = nn.Parameter( + torch.zeros(1, channels[0], 224 // 4, 224 // 4)) + HW = 224 // 4 * 224 // 4 + reduce_ratios = [4, 2, 1, 1] + + self.stages = ModuleList() + block_idx = 0 + for stage_idx, num_blocks in enumerate(self.blocks): + mid_channels = channels[stage_idx] + reduce_ratio = reduce_ratios[stage_idx] + blocks = [] + if stage_idx > 0: + blocks.append( + Sequential( + nn.Conv2d( + self.channels[stage_idx - 1], + mid_channels, + kernel_size=3, + stride=2, + padding=1), + build_norm_layer(norm_cfg, mid_channels), + )) + HW = HW // 4 + for _ in range(num_blocks): + blocks.append( + Sequential( + Grapher( + in_channels=mid_channels, + k=num_knn[block_idx], + dilation=min(block_idx // 4 + 1, max_dilation), + graph_conv_type=graph_conv_type, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + graph_conv_bias=graph_conv_bias, + use_stochastic=use_stochastic, + epsilon=epsilon, + r=reduce_ratio, + n=HW, + drop_path=dpr[block_idx], + relative_pos=True), + FFN(in_features=mid_channels, + hidden_features=mid_channels * 4, + act_cfg=act_cfg, + drop_path=dpr[block_idx]))) + block_idx += 1 + self.stages.append(Sequential(*blocks)) + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + def forward(self, inputs): + outs = [] + x = self.stem(inputs) + self.pos_embed + + for i, blocks in enumerate(self.stages): + x = blocks(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + self.stem.eval() + for i in range(self.frozen_stages): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(PyramidVig, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..21572f36b5a47a2a2f7e9832dc818acbe5840725 --- /dev/null +++ b/mmpretrain/models/backbones/vision_transformer.py @@ -0,0 +1,530 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer, + resize_pos_embed, to_2tuple) +from .base_backbone import BaseBackbone + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 0. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + ffn_type (str): Select the type of ffn layers. Defaults to 'origin'. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + layer_scale_init_value=0., + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + ffn_type='origin', + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + layer_scale_init_value=layer_scale_init_value) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + if ffn_type == 'origin': + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + layer_scale_init_value=layer_scale_init_value) + elif ffn_type == 'swiglu_fused': + self.ffn = SwiGLUFFNFused( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + layer_scale_init_value=layer_scale_init_value) + else: + raise NotImplementedError + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def init_weights(self): + super(TransformerEncoderLayer, self).init_weights() + for m in self.ffn.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + return x + + +@MODELS.register_module() +class VisionTransformer(BaseBackbone): + """Vision Transformer. + + A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 0. + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 768, + 'num_layers': 8, + 'num_heads': 8, + 'feedforward_channels': 768 * 3, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['h', 'huge'], + { + # The same as the implementation in MAE + # + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120 + }), + **dict.fromkeys( + ['eva-g', 'eva-giant'], + { + # The implementation in EVA + # + 'embed_dims': 1408, + 'num_layers': 40, + 'num_heads': 16, + 'feedforward_channels': 6144 + }), + **dict.fromkeys( + ['deit-t', 'deit-tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': 192 * 4 + }), + **dict.fromkeys( + ['deit-s', 'deit-small', 'dinov2-s', 'dinov2-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 384 * 4 + }), + **dict.fromkeys( + ['deit-b', 'deit-base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 768 * 4 + }), + **dict.fromkeys( + ['dinov2-g', 'dinov2-giant'], { + 'embed_dims': 1536, + 'num_layers': 40, + 'num_heads': 24, + 'feedforward_channels': 6144 + }), + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='cls_token', + with_cls_token=True, + frozen_stages=-1, + interpolate_mode='bicubic', + layer_scale_init_value=0., + patch_cfg=dict(), + layer_cfgs=dict(), + pre_norm=False, + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm, # disable bias if pre_norm is used(e.g., CLIP) + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + layer_scale_init_value=layer_scale_init_value, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + + self.frozen_stages = frozen_stages + if pre_norm: + self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims) + else: + self.pre_norm = nn.Identity() + + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + if self.out_type == 'avg_featmap': + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def init_weights(self): + super(VisionTransformer, self).init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if (not self.with_cls_token + and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1): + # Remove cls token from state dict if it's not used. + state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze pre-norm + for param in self.pre_norm.parameters(): + param.requires_grad = False + # freeze cls_token + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers) and self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(patch_token.mean(dim=1)) + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = self.num_layers + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name in ('cls_token', 'pos_embed'): + layer_depth = 0 + elif param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + layer_depth = layer_id + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/vit_eva02.py b/mmpretrain/models/backbones/vit_eva02.py new file mode 100644 index 0000000000000000000000000000000000000000..20ec4b247bbdbfc209c353c8e001d34d71a3990c --- /dev/null +++ b/mmpretrain/models/backbones/vit_eva02.py @@ -0,0 +1,350 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer, + resize_pos_embed) +from .vision_transformer import VisionTransformer + + +class AttentionWithRoPE(BaseModule): + """Multi-head Attention Module with 2D sincos position embedding (RoPE). + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + qkv_bias (bool): If True, add a learnable bias to q and v. Note + that we follows the official implementation where ``k_bias`` + is 0. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + rope (:obj:`torch.nn.Module`, optional): If it is an object of the + ``RotaryEmbedding``, the rotation of the token position will be + performed before the softmax. Defaults to None. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + qkv_bias=True, + qk_scale=None, + proj_bias=True, + rope=None, + with_cls_token=True, + init_cfg=None): + super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.head_dims = embed_dims // num_heads + self.scale = qk_scale or self.head_dims**-0.5 + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.with_cls_token = with_cls_token + + self.rope = rope + + def forward(self, x, patch_resolution): + B, N, _ = x.shape + + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(dim=0) + + if self.rope: + if self.with_cls_token: + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t, patch_resolution) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] if self.with_cls_token else k + ro_k_t = self.rope(k_t, patch_resolution) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + else: + q = self.rope(q, patch_resolution) + k = self.rope(k, patch_resolution) + + q = q * self.scale + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1).type_as(x) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class EVA02EndcoderLayer(BaseModule): + """Implements one encoder EVA02EndcoderLayer in EVA02. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension of FFNs. + sub_ln (bool): Whether to add the sub layer normalization + in the attention module. Defaults to False. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool): enable bias for projection in the attention module + if True. Defaults to True. + rope (:obj:`torch.nn.Module`, optional): RotaryEmbedding object + in the attention module. Defaults to None. + drop_rate (float): Dropout rate in the mlp module. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + sub_ln=False, + attn_drop=0., + proj_drop=0., + qkv_bias=False, + qk_scale=None, + proj_bias=True, + rope=None, + with_cls_token=True, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + init_cfg=None): + super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims) + + self.attn = AttentionWithRoPE( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + proj_bias=proj_bias, + rope=rope, + with_cls_token=with_cls_token) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate)) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims) + + if drop_rate > 0: + dropout_layer = dict(type='Dropout', drop_prob=drop_rate) + else: + dropout_layer = None + + if sub_ln: + ffn_norm = norm_cfg + else: + ffn_norm = None + + self.mlp = SwiGLUFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + dropout_layer=dropout_layer, + norm_cfg=ffn_norm, + add_identity=False, + ) + + def forward(self, x, patch_resolution): + inputs = x + x = self.norm1(x) + x = self.attn(x, patch_resolution) + x = self.drop_path(x) + x = inputs + x + + inputs = x + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = inputs + x + + return x + + +@MODELS.register_module() +class ViTEVA02(VisionTransformer): + """EVA02 Vision Transformer. + + A PyTorch implement of : `EVA-02: A Visual Representation for Neon Genesis + `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'tiny', 'small', 'base', 'large'. If use dict, + it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **mlp_ratio** (float): The ratio of the mlp module. + + Defaults to 'tiny'. + + sub_ln (bool): Whether to add the sub layer normalization in swiglu. + Defaults to False. + drop_rate (float): Probability of an element to be zeroed in the + mlp module. Defaults to 0. + attn_drop_rate (float): Probability of an element to be zeroed after + the softmax in the attention. Defaults to 0. + proj_drop_rate (float): Probability of an element to be zeroed after + projection in the attention. Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + **kwargs(dict, optional): Other args for Vision Transformer. + """ + arch_zoo = { + **dict.fromkeys( + ['t', 'ti', 'tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': int(192 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': int(384 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': int(768 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': int(1024 * 4 * 2 / 3) + }) + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='tiny', + sub_ln=False, + drop_rate=0., + attn_drop_rate=0., + proj_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN'), + with_cls_token=True, + layer_cfgs=dict(), + **kwargs): + # set essential args for Vision Transformer + kwargs.update( + arch=arch, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + with_cls_token=with_cls_token) + super(ViTEVA02, self).__init__(**kwargs) + + self.num_heads = self.arch_settings['num_heads'] + + # Set RoPE + head_dim = self.embed_dims // self.num_heads + self.rope = RotaryEmbeddingFast( + embed_dims=head_dim, patch_resolution=self.patch_resolution) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self. + arch_settings['feedforward_channels'], + sub_ln=sub_ln, + norm_cfg=norm_cfg, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_rate=drop_rate, + qkv_bias=qkv_bias, + rope=self.rope, + with_cls_token=with_cls_token, + drop_path_rate=dpr[i]) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(EVA02EndcoderLayer(**_layer_cfg)) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, patch_resolution) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/vit_sam.py b/mmpretrain/models/backbones/vit_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb46a72adf26cb62b93d5538116bd74f36070fa --- /dev/null +++ b/mmpretrain/models/backbones/vit_sam.py @@ -0,0 +1,697 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import LayerNorm2d, build_norm_layer, resize_pos_embed, to_2tuple +from .base_backbone import BaseBackbone + + +def window_partition(x: torch.Tensor, + window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Partition into non-overlapping windows with padding if needed. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + x (torch.Tensor): Input tokens with [B, H, W, C]. + window_size (int): Window size. + + Returns: + Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + + - ``windows``: Windows after partition with + [B * num_windows, window_size, window_size, C]. + - ``(Hp, Wp)``: Padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int]) -> torch.Tensor: + """Window unpartition into original sequences and removing padding. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + x (torch.Tensor): Input tokens with + [B * num_windows, window_size, window_size, C]. + window_size (int): Window size. + pad_hw (tuple): Padded height and width (Hp, Wp). + hw (tuple): Original height and width (H, W) before padding. + + Returns: + torch.Tensor: Unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, + rel_pos: torch.Tensor) -> torch.Tensor: + """Get relative positional embeddings according to the relative positions + of query and key sizes. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + q_size (int): Size of query q. + k_size (int): Size of key k. + rel_pos (torch.Tensor): Relative position embeddings (L, C). + + Returns: + torch.Tensor: Extracted positional embeddings according to relative + positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode='linear', + ) + rel_pos_resized = rel_pos_resized.reshape(-1, + max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - + k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """Borrowed from https://github.com/facebookresearch/segment-anything/ + + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (torch.Tensor): Attention map. + q (torch.Tensor): Query q in the attention layer with shape + (B, q_h * q_w, C). + rel_pos_h (torch.Tensor): Relative position embeddings (Lh, C) for + height axis. + rel_pos_w (torch.Tensor): Relative position embeddings (Lw, C) for + width axis. + q_size (tuple): Spatial sequence size of query q with (q_h, q_w). + k_size (tuple): Spatial sequence size of key k with (k_h, k_w). + + Returns: + torch.Tensor: Attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) + rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to False. + input_size (int, optional): Input resolution for calculating the + relative positional parameter size. Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = head_embed_dims**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dims, embed_dims) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert (input_size is not None), \ + 'Input size must be provided if using relative position embed.' + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter( + torch.zeros(2 * input_size[0] - 1, head_embed_dims)) + self.rel_pos_w = nn.Parameter( + torch.zeros(2 * input_size[1] - 1, head_embed_dims)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, + -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, + self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, + -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +class TransformerEncoderLayer(BaseModule): + """Encoder layer with window attention in Vision Transformer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to False. + window_size (int): Window size for window attention. Defaults to 0. + input_size (int, optional): Input resolution for calculating the + relative positional parameter size. Defaults to None. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + use_rel_pos: bool = False, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.window_size = window_size + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = Attention( + embed_dims=embed_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + input_size=input_size if window_size == 0 else + (window_size, window_size), + ) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x): + shortcut = x + x = self.ln1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + x = shortcut + x + + x = self.ffn(self.ln2(x), identity=x) + return x + + +@MODELS.register_module() +class ViTSAM(BaseBackbone): + """Vision Transformer as image encoder used in SAM. + + A PyTorch implement of backbone: `Segment Anything + `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'base', 'large', 'huge'. If use dict, it should have + below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + - **global_attn_indexes** (int): The index of layers with global + attention. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_channels (int): The num of output channels, if equal to 0, the + channel reduction layer is disabled. Defaults to 256. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + out_type (str): The type of output features. Please choose from + + - ``"raw"`` or ``"featmap"``: The feature map tensor from the + patch tokens with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + + Defaults to ``"raw"``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + use_abs_pos (bool): Whether to use absolute position embedding. + Defaults to True. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to True. + window_size (int): Window size for window attention. Defaults to 14. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072, + 'global_attn_indexes': [2, 5, 8, 11] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096, + 'global_attn_indexes': [5, 11, 17, 23] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120, + 'global_attn_indexes': [7, 15, 23, 31] + }), + } + OUT_TYPES = {'raw', 'featmap', 'avg_featmap'} + + def __init__(self, + arch: str = 'base', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_channels: int = 256, + out_indices: int = -1, + out_type: str = 'raw', + drop_rate: float = 0., + drop_path_rate: float = 0., + qkv_bias: bool = True, + use_abs_pos: bool = True, + use_rel_pos: bool = True, + window_size: int = 14, + norm_cfg: dict = dict(type='LN', eps=1e-6), + frozen_stages: int = -1, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: Optional[dict] = None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.global_attn_indexes = self.arch_settings['global_attn_indexes'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + self.use_abs_pos = use_abs_pos + self.interpolate_mode = interpolate_mode + if use_abs_pos: + # Set position embedding + self.pos_embed = nn.Parameter( + torch.zeros(1, *self.patch_resolution, self.embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + if use_rel_pos: + self._register_load_state_dict_pre_hook( + self._prepare_relative_position) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + window_size=window_size + if i not in self.global_attn_indexes else 0, + input_size=self.patch_resolution, + use_rel_pos=use_rel_pos, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + + self.out_channels = out_channels + if self.out_channels > 0: + self.channel_reduction = nn.Sequential( + nn.Conv2d( + self.embed_dims, + out_channels, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_channels, eps=1e-6), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_channels, eps=1e-6), + ) + + # freeze stages only when self.frozen_stages > 0 + self.frozen_stages = frozen_stages + if self.frozen_stages > 0: + self._freeze_stages() + + def init_weights(self): + super().init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze channel_reduction module + if self.frozen_stages == self.num_layers and self.out_channels > 0: + m = self.channel_reduction + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + x = x.view(B, patch_resolution[0], patch_resolution[1], + self.embed_dims) + + if self.use_abs_pos: + # 'resize_pos_embed' only supports 'pos_embed' with ndim==3, but + # in ViTSAM, the 'pos_embed' has 4 dimensions (1, H, W, C), so it + # is flattened. Besides, ViTSAM doesn't have any extra token. + resized_pos_embed = resize_pos_embed( + self.pos_embed.flatten(1, 2), + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=0) + x = x + resized_pos_embed.view(1, *patch_resolution, + self.embed_dims) + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i in self.out_indices: + # (B, H, W, C) -> (B, C, H, W) + x_reshape = x.permute(0, 3, 1, 2) + + if self.out_channels > 0: + x_reshape = self.channel_reduction(x_reshape) + outs.append(self._format_output(x_reshape)) + + return tuple(outs) + + def _format_output(self, x) -> torch.Tensor: + if self.out_type == 'raw' or self.out_type == 'featmap': + return x + elif self.out_type == 'avg_featmap': + # (B, C, H, W) -> (B, C, N) -> (B, N, C) + x = x.flatten(2).permute(0, 2, 1) + return x.mean(dim=1) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = ckpt_pos_embed_shape[1:3] + pos_embed_shape = self.patch_embed.init_out_size + + flattened_pos_embed = state_dict[name].flatten(1, 2) + resized_pos_embed = resize_pos_embed(flattened_pos_embed, + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, 0) + state_dict[name] = resized_pos_embed.view(1, *pos_embed_shape, + self.embed_dims) + + def _prepare_relative_position(self, state_dict, prefix, *args, **kwargs): + state_dict_model = self.state_dict() + all_keys = list(state_dict_model.keys()) + for key in all_keys: + if 'rel_pos_' in key: + ckpt_key = prefix + key + if ckpt_key not in state_dict: + continue + relative_position_pretrained = state_dict[ckpt_key] + relative_position_current = state_dict_model[key] + L1, _ = relative_position_pretrained.size() + L2, _ = relative_position_current.size() + if L1 != L2: + new_rel_pos = F.interpolate( + relative_position_pretrained.reshape(1, L1, + -1).permute( + 0, 2, 1), + size=L2, + mode='linear', + ) + new_rel_pos = new_rel_pos.reshape(-1, L2).permute(1, 0) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info(f'Resize the {ckpt_key} from ' + f'{state_dict[ckpt_key].shape} to ' + f'{new_rel_pos.shape}') + state_dict[ckpt_key] = new_rel_pos + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = self.num_layers + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name in ('cls_token', 'pos_embed'): + layer_depth = 0 + elif param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + layer_depth = layer_id + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/xcit.py b/mmpretrain/models/backbones/xcit.py new file mode 100644 index 0000000000000000000000000000000000000000..392ebbedf457cc199b70afa1923ec0b698f7fd5b --- /dev/null +++ b/mmpretrain/models/backbones/xcit.py @@ -0,0 +1,770 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import partial +from typing import Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks import ConvModule, DropPath +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, Sequential +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import digit_version + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer, to_2tuple +from .base_backbone import BaseBackbone + +if digit_version(torch.__version__) < digit_version('1.8.0'): + floor_div = torch.floor_divide +else: + floor_div = partial(torch.div, rounding_mode='floor') + + +class ClassAttntion(BaseModule): + """Class Attention Module. + + A PyTorch implementation of Class Attention Module introduced by: + `Going deeper with Image Transformers `_ + + taken from + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + with slight modifications to do CA + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. Defaults to 8. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + proj_drop (float): The drop out rate for linear output weights. + Defaults to 0. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ # noqa: E501 + + def __init__(self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + init_cfg=None): + + super(ClassAttntion, self).__init__(init_cfg=init_cfg) + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # We only need to calculate query of cls token. + q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, + C // self.num_heads).permute( + 0, 2, 1, 3) + k = self.k(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + q = q * self.scale + v = self.v(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) + x_cls = self.proj(x_cls) + x_cls = self.proj_drop(x_cls) + + return x_cls + + +class PositionalEncodingFourier(BaseModule): + """Positional Encoding using a fourier kernel. + + A PyTorch implementation of Positional Encoding relying on + a fourier kernel introduced by: + `Attention is all you Need `_ + + Based on the `official XCiT code + `_ + + Args: + hidden_dim (int): The hidden feature dimension. Defaults to 32. + dim (int): The output feature dimension. Defaults to 768. + temperature (int): A control variable for position encoding. + Defaults to 10000. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + hidden_dim: int = 32, + dim: int = 768, + temperature: int = 10000, + init_cfg=None): + super(PositionalEncodingFourier, self).__init__(init_cfg=init_cfg) + + self.token_projection = ConvModule( + in_channels=hidden_dim * 2, + out_channels=dim, + kernel_size=1, + conv_cfg=None, + norm_cfg=None, + act_cfg=None) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + self.eps = 1e-6 + + def forward(self, B: int, H: int, W: int): + device = self.token_projection.conv.weight.device + y_embed = torch.arange( + 1, H + 1, device=device).unsqueeze(1).repeat(1, 1, W).float() + x_embed = torch.arange(1, W + 1, device=device).repeat(1, H, 1).float() + y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, device=device).float() + dim_t = floor_div(dim_t, 2) + dim_t = self.temperature**(2 * dim_t / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + [pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], + dim=4).flatten(3) + pos_y = torch.stack( + [pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + return pos.repeat(B, 1, 1, 1) # (B, C, H, W) + + +class ConvPatchEmbed(BaseModule): + """Patch Embedding using multiple convolution layers. + + Args: + img_size (int, tuple): input image size. + Defaults to 224, means the size is 224*224. + patch_size (int): The patch size in conv patch embedding. + Defaults to 16. + in_channels (int): The input channels of this module. + Defaults to 3. + embed_dims (int): The feature dimension + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + img_size: Union[int, tuple] = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dims: int = 768, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(ConvPatchEmbed, self).__init__(init_cfg=init_cfg) + img_size = to_2tuple(img_size) + num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + conv = partial( + ConvModule, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + layer = [] + if patch_size == 16: + layer.append( + conv(in_channels=in_channels, out_channels=embed_dims // 8)) + layer.append( + conv( + in_channels=embed_dims // 8, out_channels=embed_dims // 4)) + elif patch_size == 8: + layer.append( + conv(in_channels=in_channels, out_channels=embed_dims // 4)) + else: + raise ValueError('For patch embedding, the patch size must be 16 ' + f'or 8, but get patch size {self.patch_size}.') + + layer.append( + conv(in_channels=embed_dims // 4, out_channels=embed_dims // 2)) + layer.append( + conv( + in_channels=embed_dims // 2, + out_channels=embed_dims, + act_cfg=None, + )) + + self.proj = Sequential(*layer) + + def forward(self, x: torch.Tensor): + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) # (B, N, C) + return x, (Hp, Wp) + + +class ClassAttentionBlock(BaseModule): + """Transformer block using Class Attention. + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. + mlp_ratio (float): The hidden dimension ratio for FFN. + Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initial value for layer scale. + Defaults to 1. + tokens_norm (bool): Whether to normalize all tokens or just the + cls_token in the CA. Defaults to False. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN', eps=1e-6)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop=0., + attn_drop=0., + drop_path=0., + layer_scale_init_value=1., + tokens_norm=False, + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=None): + + super(ClassAttentionBlock, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, dim) + + self.attn = ClassAttntion( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, dim) + + self.ffn = FFN( + embed_dims=dim, + feedforward_channels=int(dim * mlp_ratio), + act_cfg=act_cfg, + ffn_drop=drop, + ) + + if layer_scale_init_value > 0: + self.gamma1 = nn.Parameter(layer_scale_init_value * + torch.ones(dim)) + self.gamma2 = nn.Parameter(layer_scale_init_value * + torch.ones(dim)) + else: + self.gamma1, self.gamma2 = 1.0, 1.0 + + # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501 + self.tokens_norm = tokens_norm + + def forward(self, x): + x_norm1 = self.norm1(x) + x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1) + x = x + self.drop_path(self.gamma1 * x_attn) + if self.tokens_norm: + x = self.norm2(x) + else: + x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1) + x_res = x + cls_token = x[:, 0:1] + cls_token = self.gamma2 * self.ffn(cls_token, identity=0) + x = torch.cat([cls_token, x[:, 1:]], dim=1) + x = x_res + self.drop_path(x) + return x + + +class LPI(BaseModule): + """Local Patch Interaction module. + + A PyTorch implementation of Local Patch Interaction module + as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers + `_ + + Local Patch Interaction module that allows explicit communication between + tokens in 3x3 windows to augment the implicit communication performed by + the block diagonal scatter attention. Implemented using 2 layers of + separable 3x3 convolutions with GeLU and BatchNorm2d + + Args: + in_features (int): The input channels. + out_features (int, optional): The output channels. Defaults to None. + kernel_size (int): The kernel_size in ConvModule. Defaults to 3. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_features: int, + out_features: Optional[int] = None, + kernel_size: int = 3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(LPI, self).__init__(init_cfg=init_cfg) + + out_features = out_features or in_features + padding = kernel_size // 2 + + self.conv1 = ConvModule( + in_channels=in_features, + out_channels=in_features, + kernel_size=kernel_size, + padding=padding, + groups=in_features, + bias=True, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + order=('conv', 'act', 'norm')) + + self.conv2 = ConvModule( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=out_features, + norm_cfg=None, + act_cfg=None) + + def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: + B, N, C = x.shape + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.conv1(x) + x = self.conv2(x) + x = x.reshape(B, C, N).permute(0, 2, 1) + return x + + +class XCA(BaseModule): + r"""Cross-Covariance Attention module. + + A PyTorch implementation of Cross-Covariance Attention module + as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers + `_ + + In Cross-Covariance Attention (XCA), the channels are updated using a + weighted sum. The weights are obtained from the (softmax normalized) + Cross-covariance matrix :math:`(Q^T \cdot K \in d_h \times d_h)` + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. Defaults to 8. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + proj_drop (float): The drop out rate for linear output weights. + Defaults to 0. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + init_cfg=None): + super(XCA, self).__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + # (qkv, B, num_heads, channels per head, N) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) + + # Paper section 3.2 l2-Normalization and temperature scaling + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # (B, num_heads, C', N) -> (B, N, num_heads, C') -> (B, N C) + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class XCABlock(BaseModule): + """Transformer block using XCA. + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. + mlp_ratio (float): The hidden dimension ratio for FFNs. + Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initial value for layer scale. + Defaults to 1. + bn_norm_cfg (dict): Config dict for batchnorm in LPI and + ConvPatchEmbed. Defaults to ``dict(type='BN')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN', eps=1e-6)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + """ + + def __init__(self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + layer_scale_init_value: float = 1., + bn_norm_cfg=dict(type='BN'), + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(XCABlock, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, dim) + self.attn = XCA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + self.norm3 = build_norm_layer(norm_cfg, dim) + self.local_mp = LPI( + in_features=dim, + norm_cfg=bn_norm_cfg, + act_cfg=act_cfg, + ) + + self.norm2 = build_norm_layer(norm_cfg, dim) + self.ffn = FFN( + embed_dims=dim, + feedforward_channels=int(dim * mlp_ratio), + act_cfg=act_cfg, + ffn_drop=drop, + ) + + self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim)) + self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim)) + self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim)) + + def forward(self, x, H: int, W: int): + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) + # NOTE official code has 3 then 2, so keeping it the same to be + # consistent with loaded weights See + # https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501 + x = x + self.drop_path( + self.gamma3 * self.local_mp(self.norm3(x), H, W)) + x = x + self.drop_path( + self.gamma2 * self.ffn(self.norm2(x), identity=0)) + return x + + +@MODELS.register_module() +class XCiT(BaseBackbone): + """XCiT backbone. + + A PyTorch implementation of XCiT backbone introduced by: + `XCiT: Cross-Covariance Image Transformers + `_ + + Args: + img_size (int, tuple): Input image size. Defaults to 224. + patch_size (int): Patch size. Defaults to 16. + in_channels (int): Number of input channels. Defaults to 3. + embed_dims (int): Embedding dimension. Defaults to 768. + depth (int): depth of vision transformer. Defaults to 12. + cls_attn_layers (int): Depth of Class attention layers. + Defaults to 2. + num_heads (int): Number of attention heads. Defaults to 12. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + use_pos_embed (bool): Whether to use positional encoding. + Defaults to True. + layer_scale_init_value (float): The initial value for layer scale. + Defaults to 1. + tokens_norm (bool): Whether to normalize all tokens or just the + cls_token in the CA. Defaults to False. + out_indices (Sequence[int]): Output from which layers. + Defaults to (-1, ). + frozen_stages (int): Layers to be frozen (all param fixed), and 0 + means to freeze the stem stage. Defaults to -1, which means + not freeze any parameters. + bn_norm_cfg (dict): Config dict for the batch norm layers in LPI and + ConvPatchEmbed. Defaults to ``dict(type='BN')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN', eps=1e-6)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + """ + + def __init__(self, + img_size: Union[int, tuple] = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dims: int = 768, + depth: int = 12, + cls_attn_layers: int = 2, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + use_pos_embed: bool = True, + layer_scale_init_value: float = 1., + tokens_norm: bool = False, + out_type: str = 'cls_token', + out_indices: Sequence[int] = (-1, ), + final_norm: bool = True, + frozen_stages: int = -1, + bn_norm_cfg=dict(type='BN'), + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=dict(type='TruncNormal', layer='Linear')): + super(XCiT, self).__init__(init_cfg=init_cfg) + + img_size = to_2tuple(img_size) + if (img_size[0] % patch_size != 0) or (img_size[1] % patch_size != 0): + raise ValueError(f'`patch_size` ({patch_size}) should divide ' + f'the image shape ({img_size}) evenly.') + + self.embed_dims = embed_dims + + assert out_type in ('raw', 'featmap', 'avg_featmap', 'cls_token') + self.out_type = out_type + + self.patch_embed = ConvPatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims, + norm_cfg=bn_norm_cfg, + act_cfg=act_cfg, + ) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.use_pos_embed = use_pos_embed + if use_pos_embed: + self.pos_embed = PositionalEncodingFourier(dim=embed_dims) + self.pos_drop = nn.Dropout(p=drop_rate) + + self.xca_layers = nn.ModuleList() + self.ca_layers = nn.ModuleList() + self.num_layers = depth + cls_attn_layers + + for _ in range(depth): + self.xca_layers.append( + XCABlock( + dim=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + bn_norm_cfg=bn_norm_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + layer_scale_init_value=layer_scale_init_value, + )) + + for _ in range(cls_attn_layers): + self.ca_layers.append( + ClassAttentionBlock( + dim=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + tokens_norm=tokens_norm, + )) + + if final_norm: + self.norm = build_norm_layer(norm_cfg, embed_dims) + + # Transform out_indices + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + if frozen_stages > self.num_layers + 1: + raise ValueError('frozen_stages must be less than ' + f'{self.num_layers} but get {frozen_stages}') + self.frozen_stages = frozen_stages + + def init_weights(self): + super().init_weights() + + if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained': + return + + trunc_normal_(self.cls_token, std=.02) + + def _freeze_stages(self): + if self.frozen_stages < 0: + return + + # freeze position embedding + if self.use_pos_embed: + self.pos_embed.eval() + for param in self.pos_embed.parameters(): + param.requires_grad = False + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # set dropout to eval model + self.pos_drop.eval() + # freeze cls_token, only use in self.Clslayers + if self.frozen_stages > len(self.xca_layers): + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages): + if i <= len(self.xca_layers): + m = self.xca_layers[i - 1] + else: + m = self.ca_layers[i - len(self.xca_layers) - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze the last layer norm if all_stages are frozen + if self.frozen_stages == len(self.xca_layers) + len(self.ca_layers): + self.norm.eval() + for param in self.norm.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + B = x.shape[0] + # x is (B, N, C). (Hp, Hw) is the patch resolution + x, (Hp, Wp) = self.patch_embed(x) + + if self.use_pos_embed: + # (B, C, Hp, Wp) -> (B, C, N) -> (B, N, C) + pos_encoding = self.pos_embed(B, Hp, Wp) + x = x + pos_encoding.reshape(B, -1, x.size(1)).permute(0, 2, 1) + x = self.pos_drop(x) + + for i, layer in enumerate(self.xca_layers): + x = layer(x, Hp, Wp) + if i in self.out_indices: + outs.append(self._format_output(x, (Hp, Wp), False)) + + x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) + + for i, layer in enumerate(self.ca_layers): + x = layer(x) + if i == len(self.ca_layers) - 1: + x = self.norm(x) + if i + len(self.xca_layers) in self.out_indices: + outs.append(self._format_output(x, (Hp, Wp), True)) + + return tuple(outs) + + def _format_output(self, x, hw, with_cls_token: bool): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + if not with_cls_token: + raise ValueError( + 'Cannot output cls_token since there is no cls_token.') + return x[:, 0] + + patch_token = x[:, 1:] if with_cls_token else x + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/builder.py b/mmpretrain/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea4e25c8d6db3bbf07ab94ea08c08e474ec3595 --- /dev/null +++ b/mmpretrain/models/builder.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpretrain.registry import MODELS + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +CLASSIFIERS = MODELS +RETRIEVER = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_classifier(cfg): + """Build classifier.""" + return CLASSIFIERS.build(cfg) + + +def build_retriever(cfg): + """Build retriever.""" + return RETRIEVER.build(cfg) diff --git a/mmpretrain/models/classifiers/__init__.py b/mmpretrain/models/classifiers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa276ff5a2152beb93c4d1b42e6bbf4e2cbf822 --- /dev/null +++ b/mmpretrain/models/classifiers/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseClassifier +from .hugging_face import HuggingFaceClassifier +from .image import ImageClassifier +from .timm import TimmClassifier + +__all__ = [ + 'BaseClassifier', 'ImageClassifier', 'TimmClassifier', + 'HuggingFaceClassifier' +] diff --git a/mmpretrain/models/classifiers/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/classifiers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72fcb79de3557ec2132d697b925dc38c7b71c7e5 Binary files /dev/null and b/mmpretrain/models/classifiers/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/classifiers/__pycache__/base.cpython-38.pyc b/mmpretrain/models/classifiers/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce8dc4e81bb722149cbc6532c40548d81736d22c Binary files /dev/null and b/mmpretrain/models/classifiers/__pycache__/base.cpython-38.pyc differ diff --git a/mmpretrain/models/classifiers/__pycache__/hugging_face.cpython-38.pyc b/mmpretrain/models/classifiers/__pycache__/hugging_face.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b56fbe4048ee5bbce3e5cd31d30bcc449fd7b33c Binary files /dev/null and b/mmpretrain/models/classifiers/__pycache__/hugging_face.cpython-38.pyc differ diff --git a/mmpretrain/models/classifiers/__pycache__/image.cpython-38.pyc b/mmpretrain/models/classifiers/__pycache__/image.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4119e7f8e972377b76b965b0c73d4d22ce7af80 Binary files /dev/null and b/mmpretrain/models/classifiers/__pycache__/image.cpython-38.pyc differ diff --git a/mmpretrain/models/classifiers/__pycache__/timm.cpython-38.pyc b/mmpretrain/models/classifiers/__pycache__/timm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b062be56e54003776ee87358123fb850dc11f66 Binary files /dev/null and b/mmpretrain/models/classifiers/__pycache__/timm.cpython-38.pyc differ diff --git a/mmpretrain/models/classifiers/base.py b/mmpretrain/models/classifiers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a65fc213f4bfe271a9298b823ba38fc4ca9f57e1 --- /dev/null +++ b/mmpretrain/models/classifiers/base.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Sequence + +import torch +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement + + +class BaseClassifier(BaseModel, metaclass=ABCMeta): + """Base class for classifiers. + + Args: + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None, it will use "BaseDataPreprocessor" as type, see + :class:`mmengine.model.BaseDataPreprocessor` for more details. + Defaults to None. + + Attributes: + init_cfg (dict): Initialization config dict. + data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An + extra data pre-processing module, which processes data from + dataloader to the format accepted by :meth:`forward`. + """ + + def __init__(self, + init_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None): + super(BaseClassifier, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + @property + def with_neck(self) -> bool: + """Whether the classifier has a neck.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_head(self) -> bool: + """Whether the classifier has a head.""" + return hasattr(self, 'head') and self.head is not None + + @abstractmethod + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`BaseDataElement`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C, ...) + in general. + data_samples (List[BaseDataElement], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of + :obj:`mmengine.BaseDataElement`. + - If ``mode="loss"``, return a dict of tensor. + """ + pass + + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The sub-classes are recommended to implement this method to extract + features from backbone and neck. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + """ + raise NotImplementedError + + def extract_feats(self, multi_inputs: Sequence[torch.Tensor], + **kwargs) -> list: + """Extract features from a sequence of input tensor. + + Args: + multi_inputs (Sequence[torch.Tensor]): A sequence of input + tensor. It can be used in augmented inference. + **kwargs: Other keyword arguments accepted by :meth:`extract_feat`. + + Returns: + list: Features of every input tensor. + """ + assert isinstance(multi_inputs, Sequence), \ + '`extract_feats` is used for a sequence of inputs tensor. If you '\ + 'want to extract on single inputs tensor, use `extract_feat`.' + return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs] diff --git a/mmpretrain/models/classifiers/hugging_face.py b/mmpretrain/models/classifiers/hugging_face.py new file mode 100644 index 0000000000000000000000000000000000000000..26a8fda51b0d01ee54ba71665caedbb8a7bd842c --- /dev/null +++ b/mmpretrain/models/classifiers/hugging_face.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All right reserved. +import re +from collections import OrderedDict +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from mmpretrain.utils import require +from .base import BaseClassifier + + +@MODELS.register_module() +class HuggingFaceClassifier(BaseClassifier): + """Image classifiers for HuggingFace model. + + This class accepts all positional and keyword arguments of the API + ``from_pretrained`` (when ``pretrained=True``) and ``from_config`` (when + ``pretrained=False``) of `transformers.AutoModelForImageClassification`_ + and use it to create a model from hugging-face. + + It can load checkpoints of hugging-face directly, and the saved checkpoints + also can be directly load by hugging-face. + + Please confirm that you have installed ``transfromers`` if you want to use it. + + .. _transformers.AutoModelForImageClassification: + https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification + + Args: + model_name (str): The name of the model to use in hugging-face. + pretrained (bool): Whether to load pretrained checkpoint from + hugging-face. Defaults to False. + *args: Other positional arguments of the method + `from_pretrained` or `from_config`. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in :mod:`mmpretrain.model.utils.augment`. + + Defaults to None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + **kwargs: Other keyword arguments of the method + `from_pretrained` or `from_config`. + + Examples: + >>> import torch + >>> from mmpretrain.models import build_classifier + >>> cfg = dict(type='HuggingFaceClassifier', model_name='microsoft/resnet-50', pretrained=True) + >>> model = build_classifier(cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> out = model(inputs) + >>> print(out.shape) + torch.Size([1, 1000]) + """ # noqa: E501 + + @require('transformers') + def __init__(self, + model_name, + pretrained=False, + *model_args, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + train_cfg: Optional[dict] = None, + with_cp: bool = False, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs): + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. + data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + from transformers import AutoConfig, AutoModelForImageClassification + if pretrained: + self.model = AutoModelForImageClassification.from_pretrained( + model_name, *model_args, **kwargs) + else: + config = AutoConfig.from_pretrained(model_name, *model_args, + **kwargs) + self.model = AutoModelForImageClassification.from_config(config) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + self.with_cp = with_cp + if self.with_cp: + self.model.gradient_checkpointing_enable() + + self._register_state_dict_hook(self._remove_state_dict_prefix) + self._register_load_state_dict_pre_hook(self._add_state_dict_prefix) + + def forward(self, inputs, data_samples=None, mode='tensor'): + if mode == 'tensor': + return self.model(inputs).logits + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + raise NotImplementedError( + "The HuggingFaceClassifier doesn't support extract feature yet.") + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs): + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments of the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs).logits + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None): + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + + Returns: + List[DataSample]: The prediction results. + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs).logits + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + + return data_samples + + @staticmethod + def _remove_state_dict_prefix(self, state_dict, prefix, local_metadata): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_key = re.sub(f'^{prefix}model.', prefix, k) + new_state_dict[new_key] = v + return new_state_dict + + @staticmethod + def _add_state_dict_prefix(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + new_prefix = prefix + 'model.' + for k in list(state_dict.keys()): + new_key = re.sub(f'^{prefix}', new_prefix, k) + state_dict[new_key] = state_dict[k] + del state_dict[k] diff --git a/mmpretrain/models/classifiers/image.py b/mmpretrain/models/classifiers/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6d0edd7aed8ce34a11b6cbbbdf2034bbcd1c652b --- /dev/null +++ b/mmpretrain/models/classifiers/image.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseClassifier + + +@MODELS.register_module() +class ImageClassifier(BaseClassifier): + """Image classifiers for supervised classification task. + + Args: + backbone (dict): The backbone module. See + :mod:`mmpretrain.models.backbones`. + neck (dict, optional): The neck module to process features from + backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. + head (dict, optional): The head module to do prediction and calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, almost all methods cannot be + used except :meth:`extract_feat`. Defaults to None. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in + :mod:`mmpretrain.model.utils.augment`. + - probs (List[float], optional): The probability of every batch + augmentation methods. If None, choose evenly. Defaults to None. + + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + pretrained: Optional[str] = None, + train_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if pretrained is not None: + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + data_preprocessor = data_preprocessor or {} + + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'ClsDataPreprocessor') + data_preprocessor.setdefault('batch_augments', train_cfg) + data_preprocessor = MODELS.build(data_preprocessor) + elif not isinstance(data_preprocessor, nn.Module): + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + super(ImageClassifier, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(backbone, nn.Module): + backbone = MODELS.build(backbone) + if neck is not None and not isinstance(neck, nn.Module): + neck = MODELS.build(neck) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + + self.backbone = backbone + self.neck = neck + self.head = head + + # If the model needs to load pretrain weights from a third party, + # the key can be modified with this hook + if hasattr(self.backbone, '_checkpoint_filter'): + self._register_load_state_dict_pre_hook( + self.backbone._checkpoint_filter) + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor(s) without any + post-processing, same as a common PyTorch Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of + :obj:`mmpretrain.structures.DataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + feats = self.extract_feat(inputs) + return self.head(feats) if self.with_head else feats + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs, stage='neck'): + """Extract features from the input tensor with shape (N, C, ...). + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + stage (str): Which stage to output the feature. Choose from: + + - "backbone": The output of backbone network. Returns a tuple + including multiple stages features. + - "neck": The output of neck module. Returns a tuple including + multiple stages features. + - "pre_logits": The feature before the final classification + linear layer. Usually returns a tensor. + + Defaults to "neck". + + Returns: + tuple | Tensor: The output of specified stage. + The output depends on detailed implementation. In general, the + output of backbone and neck is a tuple and the output of + pre_logits is a tensor. + + Examples: + 1. Backbone output + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 256, 14, 14]) + torch.Size([1, 512, 7, 7]) + + 2. Neck output + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64]) + torch.Size([1, 128]) + torch.Size([1, 256]) + torch.Size([1, 512]) + + 3. Pre-logits output (without the final linear classifier head) + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model + >>> model = build_classifier(cfg) + >>> + >>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits') + >>> print(out.shape) # The hidden dims in head is 3072 + torch.Size([1, 3072]) + """ # noqa: E501 + assert stage in ['backbone', 'neck', 'pre_logits'], \ + (f'Invalid output stage "{stage}", please choose from "backbone", ' + '"neck" and "pre_logits"') + + x = self.backbone(inputs) + + if stage == 'backbone': + return x + + if self.with_neck: + x = self.neck(x) + if stage == 'neck': + return x + + assert self.with_head and hasattr(self.head, 'pre_logits'), \ + "No head or the head doesn't implement `pre_logits` method." + return self.head.pre_logits(x) + + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + feats = self.extract_feat(inputs) + return self.head.loss(feats, data_samples) + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **kwargs) -> List[DataSample]: + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + """ + feats = self.extract_feat(inputs) + return self.head.predict(feats, data_samples, **kwargs) + + def get_layer_depth(self, param_name: str): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + + Returns: + Tuple[int, int]: The layer-wise depth and the max depth. + """ + if hasattr(self.backbone, 'get_layer_depth'): + return self.backbone.get_layer_depth(param_name, 'backbone.') + else: + raise NotImplementedError( + f"The backbone {type(self.backbone)} doesn't " + 'support `get_layer_depth` by now.') diff --git a/mmpretrain/models/classifiers/timm.py b/mmpretrain/models/classifiers/timm.py new file mode 100644 index 0000000000000000000000000000000000000000..d777b2e039d848b01fc9c6b6eaae6619bebb8938 --- /dev/null +++ b/mmpretrain/models/classifiers/timm.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All right reserved. +import re +from collections import OrderedDict +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from mmpretrain.utils import require +from .base import BaseClassifier + + +@MODELS.register_module() +class TimmClassifier(BaseClassifier): + """Image classifiers for pytorch-image-models (timm) model. + + This class accepts all positional and keyword arguments of the function + `timm.models.create_model `_ and use + it to create a model from pytorch-image-models. + + It can load checkpoints of timm directly, and the saved checkpoints also + can be directly load by timm. + + Please confirm that you have installed ``timm`` if you want to use it. + + Args: + *args: All positional arguments of the function + `timm.models.create_model`. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in :mod:`mmpretrain.model.utils.augment`. + + Defaults to None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + **kwargs: Other keyword arguments of the function + `timm.models.create_model`. + + Examples: + >>> import torch + >>> from mmpretrain.models import build_classifier + >>> cfg = dict(type='TimmClassifier', model_name='resnet50', pretrained=True) + >>> model = build_classifier(cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> out = model(inputs) + >>> print(out.shape) + torch.Size([1, 1000]) + """ # noqa: E501 + + @require('timm') + def __init__(self, + *args, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + train_cfg: Optional[dict] = None, + with_cp: bool = False, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs): + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. + data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + from timm.models import create_model + self.model = create_model(*args, **kwargs) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + self.with_cp = with_cp + if self.with_cp: + self.model.set_grad_checkpointing() + + self._register_state_dict_hook(self._remove_state_dict_prefix) + self._register_load_state_dict_pre_hook(self._add_state_dict_prefix) + + def forward(self, inputs, data_samples=None, mode='tensor'): + if mode == 'tensor': + return self.model(inputs) + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + if hasattr(self.model, 'forward_features'): + return self.model.forward_features(inputs) + else: + raise NotImplementedError( + f"The model {type(self.model)} doesn't support extract " + "feature because it don't have `forward_features` method.") + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs): + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments of the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module(cls_score, target, **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None): + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + + Returns: + List[DataSample]: The prediction results. + """ + # The part can be traced by torch.fx + cls_score = self(inputs) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples=None): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + + return data_samples + + @staticmethod + def _remove_state_dict_prefix(self, state_dict, prefix, local_metadata): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_key = re.sub(f'^{prefix}model.', prefix, k) + new_state_dict[new_key] = v + return new_state_dict + + @staticmethod + def _add_state_dict_prefix(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + new_prefix = prefix + 'model.' + for k in list(state_dict.keys()): + new_key = re.sub(f'^{prefix}', new_prefix, k) + state_dict[new_key] = state_dict[k] + del state_dict[k] diff --git a/mmpretrain/models/heads/__init__.py b/mmpretrain/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4364fb5626f321196952bc07bc2f54e3788a0ebe --- /dev/null +++ b/mmpretrain/models/heads/__init__.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .beitv1_head import BEiTV1Head +from .beitv2_head import BEiTV2Head +from .cae_head import CAEHead +from .cls_head import ClsHead +from .conformer_head import ConformerHead +from .contrastive_head import ContrastiveHead +from .deit_head import DeiTClsHead +from .efficientformer_head import EfficientFormerClsHead +from .grounding_head import GroundingHead +from .itc_head import ITCHead +from .itm_head import ITMHead +from .itpn_clip_head import iTPNClipHead +from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead +from .levit_head import LeViTClsHead +from .linear_head import LinearClsHead +from .mae_head import MAEPretrainHead +from .margin_head import ArcFaceClsHead +from .mim_head import MIMHead +from .mixmim_head import MixMIMPretrainHead +from .mocov3_head import MoCoV3Head +from .multi_label_cls_head import MultiLabelClsHead +from .multi_label_csra_head import CSRAClsHead +from .multi_label_linear_head import MultiLabelLinearClsHead +from .multi_task_head import MultiTaskHead +from .seq_gen_head import SeqGenerationHead +from .simmim_head import SimMIMHead +from .spark_head import SparKPretrainHead +from .stacked_head import StackedLinearClsHead +from .swav_head import SwAVHead +from .vig_head import VigClsHead +from .vision_transformer_head import VisionTransformerClsHead +from .vqa_head import VQAGenerationHead + +__all__ = [ + 'ClsHead', + 'LinearClsHead', + 'StackedLinearClsHead', + 'MultiLabelClsHead', + 'MultiLabelLinearClsHead', + 'VisionTransformerClsHead', + 'DeiTClsHead', + 'ConformerHead', + 'EfficientFormerClsHead', + 'ArcFaceClsHead', + 'CSRAClsHead', + 'MultiTaskHead', + 'LeViTClsHead', + 'VigClsHead', + 'BEiTV1Head', + 'BEiTV2Head', + 'CAEHead', + 'ContrastiveHead', + 'LatentCrossCorrelationHead', + 'LatentPredictHead', + 'MAEPretrainHead', + 'MixMIMPretrainHead', + 'SwAVHead', + 'MoCoV3Head', + 'MIMHead', + 'SimMIMHead', + 'SeqGenerationHead', + 'VQAGenerationHead', + 'ITCHead', + 'ITMHead', + 'GroundingHead', + 'iTPNClipHead', + 'SparKPretrainHead', +] diff --git a/mmpretrain/models/heads/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94e813dee4a0c23414f30e4b96cbc896e3b530da Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/beitv1_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/beitv1_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60cedc2d95cdbfe7099c1491f314c980365acdc3 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/beitv1_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/beitv2_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/beitv2_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a812272289d214abce5a79bd5f1bd33d54197333 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/beitv2_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/cae_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/cae_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2657aeb71d69a8b51eb3a4122a682942c5685595 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/cae_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/cls_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/cls_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1984876c3d3d91b0b89cd99ffb14b70eef47e1d Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/cls_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/conformer_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/conformer_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6e5b5cd691270806c51686cb85745c337dd423e Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/conformer_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/contrastive_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/contrastive_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3541f4869bbd1d3d10a64eb1cce95abfc8c1ec5 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/contrastive_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/deit_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/deit_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b121d58a525a272cfb5dac82c9e6ce990450cf1 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/deit_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/efficientformer_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/efficientformer_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22499e8253f99ed0f16e85cac57955a96ffdd61e Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/efficientformer_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/grounding_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/grounding_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac3e342be4fe3db1589a1f48a3b2339707a53c5b Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/grounding_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/itc_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/itc_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed4aa59f3a459b9ae2094a8063b7d91dffbb6c4d Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/itc_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/itm_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/itm_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bac96443c7e8207a0de1f0c0563d0948cd8d57d Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/itm_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/itpn_clip_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/itpn_clip_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec33d0aef5d64164204898eb8bef2fd4a739e8b8 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/itpn_clip_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/latent_heads.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/latent_heads.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d96ae379a2264cd21eb4bfcfcf108883de5d050f Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/latent_heads.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/levit_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/levit_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b34f3832e1bc511de3e05a68c7551d0bed07feb Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/levit_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/linear_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/linear_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40fc5df017a20d55975947624ae837742cfc5b13 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/linear_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/mae_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/mae_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..991c108da6ebaf7866f2cc2c535dd2f0123027a7 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/mae_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/margin_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/margin_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b25d23b0f2faa37efee6e7031d3595fd96c7226f Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/margin_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/mim_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/mim_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b5d6c0aeb41e73f714dd80894c515794c77fa6d Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/mim_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/mixmim_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/mixmim_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8afb5b45430e735105b3a9f018bfe92383679a5a Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/mixmim_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/mocov3_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/mocov3_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bea65ba600be30e59d3c8341bbeb24432177ef61 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/mocov3_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/multi_label_cls_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/multi_label_cls_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eea2b289dbe3f488e6ffecb58c54e0b6aec8452d Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/multi_label_cls_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/multi_label_csra_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/multi_label_csra_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a721c8f7df8e9d3e15dd62c5cf581ace1de21d76 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/multi_label_csra_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/multi_label_linear_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/multi_label_linear_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12c1a48ed5e5f373cba8c372b43d7bc2df519be6 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/multi_label_linear_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/multi_task_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/multi_task_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7c1c8447398201a1ffe223f61f093c778bcd014 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/multi_task_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/seq_gen_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/seq_gen_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48a60306de0085a6502e295d0c59f5855f8342ad Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/seq_gen_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/simmim_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/simmim_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b330186b7a0e52938b5029f047b4a3fd839724c8 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/simmim_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/spark_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/spark_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7ca62a274ea64f52f4a968052ca506c26438772 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/spark_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/stacked_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/stacked_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97853ac33a9293c7e80a4c57c261949ec3cdae7f Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/stacked_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/swav_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/swav_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c013c0a4fbcf2913341f4f1cfa3a21a6260bd698 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/swav_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/vig_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/vig_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9862848785a6da5697101b8ae4e85e1d5f422e0b Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/vig_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/vision_transformer_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/vision_transformer_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05d877f284d480feca5f8e5a5591704da5f85ea5 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/vision_transformer_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/__pycache__/vqa_head.cpython-38.pyc b/mmpretrain/models/heads/__pycache__/vqa_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d0eb3e2ec670f328bc2f125e1b74f513cbbc618 Binary files /dev/null and b/mmpretrain/models/heads/__pycache__/vqa_head.cpython-38.pyc differ diff --git a/mmpretrain/models/heads/beitv1_head.py b/mmpretrain/models/heads/beitv1_head.py new file mode 100644 index 0000000000000000000000000000000000000000..df422ea71c9090d1ab084bbc93c8889a4f2f402e --- /dev/null +++ b/mmpretrain/models/heads/beitv1_head.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class BEiTV1Head(BaseModule): + """Head for BEiT v1 Pre-training. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss_module = MODELS.build(loss) + + def loss(self, feats: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraing. + """ + mask = mask.flatten(1).to(torch.bool) + target = torch.argmax(target, dim=1).flatten(1) + target = target[mask] + + # remove cls_token + feats = feats[:, 1:] + logits = self.cls_head(feats[mask]) + + loss = self.loss_module(logits, target) + return loss diff --git a/mmpretrain/models/heads/beitv2_head.py b/mmpretrain/models/heads/beitv2_head.py new file mode 100644 index 0000000000000000000000000000000000000000..cf677a2cf7c1a3964f1ba884a0ccae83f8b70a40 --- /dev/null +++ b/mmpretrain/models/heads/beitv2_head.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class BEiTV2Head(BaseModule): + """Head for BEiT v2 Pre-training. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss_module = MODELS.build(loss) + + def loss(self, feats: torch.Tensor, feats_cls_pt: torch.Tensor, + target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + feats_cls_pt (torch.Tensor) : Features from class late layers for + pretraining. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraing. + """ + mask = mask.flatten(1).to(torch.bool) + target = target[mask] + + # shared cls head + logits = self.cls_head(feats[mask]) + logits_cls_pt = self.cls_head(feats_cls_pt[mask]) + + loss_1 = self.loss_module(logits, target) + loss_2 = self.loss_module(logits_cls_pt, target) + return loss_1, loss_2 diff --git a/mmpretrain/models/heads/cae_head.py b/mmpretrain/models/heads/cae_head.py new file mode 100644 index 0000000000000000000000000000000000000000..18a07f0a79297c35a39b9b2da0d25bf1eac6e70b --- /dev/null +++ b/mmpretrain/models/heads/cae_head.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CAEHead(BaseModule): + """Head for CAE Pre-training. + + Compute the align loss and the main loss. In addition, this head also + generates the prediction target generated by dalle. + + Args: + loss (dict): The config of loss. + tokenizer_path (str): The path of the tokenizer. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + + @torch.no_grad() + def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor: + """Generate the reconstruction target. + + Args: + logits_target (torch.Tensor): The logits generated by DALL-E.s + + Returns: + torch.Tensor: The logits target. + """ + target = torch.argmax(logits_target, dim=1) + return target.flatten(1) + + def loss(self, logits: torch.Tensor, logits_target: torch.Tensor, + latent_pred: torch.Tensor, latent_target: torch.Tensor, + mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate loss. + + Args: + logits (torch.Tensor): Logits generated by decoder. + logits_target (img_target): Target generated by dalle for decoder + prediction. + latent_pred (torch.Tensor): Latent prediction by regressor. + latent_target (torch.Tensor): Target for latent prediction, + generated by teacher. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The tuple of loss. + - ``loss_main`` (torch.Tensor): Cross entropy loss. + - ``loss_align`` (torch.Tensor): MSE loss. + """ + + target = self._generate_target(logits_target) # target features + target = target[mask].detach() + + # loss main for decoder, loss align for regressor + loss_main, loss_align = self.loss_module(logits, target, latent_pred, + latent_target) + + return (loss_main, loss_align) diff --git a/mmpretrain/models/heads/cls_head.py b/mmpretrain/models/heads/cls_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac4c51804122adbc92df8c8748e4109e205110f --- /dev/null +++ b/mmpretrain/models/heads/cls_head.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.evaluation.metrics import Accuracy +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class ClsHead(BaseModule): + """Classification head. + + Args: + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. + cal_acc (bool): Whether to calculate accuracy during training. + If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + topk: Union[int, Tuple[int]] = (1, ), + cal_acc: bool = False, + init_cfg: Optional[dict] = None): + super(ClsHead, self).__init__(init_cfg=init_cfg) + + self.topk = topk + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + self.cal_acc = cal_acc + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ClsHead``, we just obtain the feature + of the last stage. + """ + # The ClsHead doesn't have other module, just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The ClsHead doesn't have the final classification head, + # just return the unpacked inputs. + return pre_logits + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + # compute accuracy + if self.cal_acc: + assert target.ndim == 1, 'If you enable batch augmentation ' \ + 'like mixup during training, `cal_acc` is pointless.' + acc = Accuracy.calculate(cls_score, target, topk=self.topk) + losses.update( + {f'accuracy_top-{k}': a + for k, a in zip(self.topk, acc)}) + + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: Optional[List[Optional[DataSample]]] = None + ) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample | None], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples diff --git a/mmpretrain/models/heads/conformer_head.py b/mmpretrain/models/heads/conformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..eade90d567b5cb9189f62919ad9a6a0e9c47ae23 --- /dev/null +++ b/mmpretrain/models/heads/conformer_head.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.evaluation.metrics import Accuracy +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +@MODELS.register_module() +class ConformerHead(ClsHead): + """Linear classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (Sequence[int]): Number of channels in the input + feature map. + init_cfg (dict | optional): The extra init config of layers. + Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + + def __init__( + self, + num_classes: int, + in_channels: Sequence[int], # [conv_dim, trans_dim] + init_cfg: dict = dict(type='TruncNormal', layer='Linear', std=.02), + **kwargs): + super(ConformerHead, self).__init__(init_cfg=init_cfg, **kwargs) + + self.in_channels = in_channels + self.num_classes = num_classes + self.init_cfg = init_cfg + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes) + self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes) + + def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ConformerHead``, we just obtain the + feature of the last stage. + """ + # The ConformerHead doesn't have other module, + # just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]: + """The forward process.""" + x = self.pre_logits(feats) + # There are two outputs in the Conformer model + assert len(x) == 2 + + conv_cls_score = self.conv_cls_head(x[0]) + tran_cls_score = self.trans_cls_head(x[1]) + + return conv_cls_score, tran_cls_score + + def predict(self, + feats: Tuple[List[torch.Tensor]], + data_samples: List[DataSample] = None) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + conv_cls_score, tran_cls_score = self(feats) + cls_score = conv_cls_score + tran_cls_score + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_loss(self, cls_score: Tuple[torch.Tensor], + data_samples: List[DataSample], **kwargs) -> dict: + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = sum([ + self.loss_module( + score, target, avg_factor=score.size(0), **kwargs) + for score in cls_score + ]) + losses['loss'] = loss + + # compute accuracy + if self.cal_acc: + assert target.ndim == 1, 'If you enable batch augmentation ' \ + 'like mixup during training, `cal_acc` is pointless.' + acc = Accuracy.calculate( + cls_score[0] + cls_score[1], target, topk=self.topk) + losses.update( + {f'accuracy_top-{k}': a + for k, a in zip(self.topk, acc)}) + + return losses diff --git a/mmpretrain/models/heads/contrastive_head.py b/mmpretrain/models/heads/contrastive_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6d1474aed59e2912ca4b5c24ce5a2430f50cb913 --- /dev/null +++ b/mmpretrain/models/heads/contrastive_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class ContrastiveHead(BaseModule): + """Head for contrastive learning. + + The contrastive loss is implemented in this head and is used in SimCLR, + MoCo, DenseCL, etc. + + Args: + loss (dict): Config dict for module of loss functions. + temperature (float): The temperature hyper-parameter that + controls the concentration level of the distribution. + Defaults to 0.1. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + temperature: float = 0.1, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + self.temperature = temperature + + def loss(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor: + """Forward function to compute contrastive loss. + + Args: + pos (torch.Tensor): Nx1 positive similarity. + neg (torch.Tensor): Nxk negative similarity. + + Returns: + torch.Tensor: The contrastive loss. + """ + N = pos.size(0) + logits = torch.cat((pos, neg), dim=1) + logits /= self.temperature + labels = torch.zeros((N, ), dtype=torch.long).to(pos.device) + + loss = self.loss_module(logits, labels) + return loss diff --git a/mmpretrain/models/heads/deit_head.py b/mmpretrain/models/heads/deit_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a96f6e152711d23646e02312218c0c85e96300e8 --- /dev/null +++ b/mmpretrain/models/heads/deit_head.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .vision_transformer_head import VisionTransformerClsHead + + +@MODELS.register_module() +class DeiTClsHead(VisionTransformerClsHead): + """Distilled Vision Transformer classifier head. + + Comparing with the :class:`VisionTransformerClsHead`, this head adds an + extra linear layer to handle the dist token. The final classification score + is the average of both linear transformation results of ``cls_token`` and + ``dist_token``. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int, optional): Number of the dimensions for hidden layer. + Defaults to None, which means no extra hidden layer. + act_cfg (dict): The activation config. Only available during + pre-training. Defaults to ``dict(type='Tanh')``. + init_cfg (dict): The extra initialization configs. Defaults to + ``dict(type='Constant', layer='Linear', val=0)``. + """ + + def _init_layers(self): + """"Init extra hidden linear layer to handle dist token if exists.""" + super(DeiTClsHead, self)._init_layers() + if self.hidden_dim is None: + head_dist = nn.Linear(self.in_channels, self.num_classes) + else: + head_dist = nn.Linear(self.hidden_dim, self.num_classes) + self.layers.add_module('head_dist', head_dist) + + def pre_logits(self, + feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]: + """The process before the final classification head. + + The input ``feats`` is a tuple of list of tensor, and each tensor is + the feature of a backbone stage. In ``DeiTClsHead``, we obtain the + feature of the last stage and forward in hidden layer if exists. + """ + feat = feats[-1] # Obtain feature of the last scale. + # For backward-compatibility with the previous ViT output + if len(feat) == 3: + _, cls_token, dist_token = feat + else: + cls_token, dist_token = feat + if self.hidden_dim is None: + return cls_token, dist_token + else: + cls_token = self.layers.act(self.layers.pre_logits(cls_token)) + dist_token = self.layers.act(self.layers.pre_logits(dist_token)) + return cls_token, dist_token + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The forward process.""" + if self.training: + warnings.warn('MMPretrain cannot train the ' + 'distilled version DeiT.') + cls_token, dist_token = self.pre_logits(feats) + # The final classification head. + cls_score = (self.layers.head(cls_token) + + self.layers.head_dist(dist_token)) / 2 + return cls_score diff --git a/mmpretrain/models/heads/efficientformer_head.py b/mmpretrain/models/heads/efficientformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..09aa05b28533028723f599881777939a48982319 --- /dev/null +++ b/mmpretrain/models/heads/efficientformer_head.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +@MODELS.register_module() +class EfficientFormerClsHead(ClsHead): + """EfficientFormer classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + distillation (bool): Whether use a additional distilled head. + Defaults to True. + init_cfg (dict): The extra initialization configs. Defaults to + ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + + def __init__(self, + num_classes, + in_channels, + distillation=True, + init_cfg=dict(type='Normal', layer='Linear', std=0.01), + *args, + **kwargs): + super(EfficientFormerClsHead, self).__init__( + init_cfg=init_cfg, *args, **kwargs) + self.in_channels = in_channels + self.num_classes = num_classes + self.dist = distillation + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.head = nn.Linear(self.in_channels, self.num_classes) + if self.dist: + self.dist_head = nn.Linear(self.in_channels, self.num_classes) + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.head(pre_logits) + + if self.dist: + cls_score = (cls_score + self.dist_head(pre_logits)) / 2 + return cls_score + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In :obj`EfficientFormerClsHead`, we just + obtain the feature of the last stage. + """ + # The EfficientFormerClsHead doesn't have other module, just return + # after unpacking. + return feats[-1] + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + if self.dist: + raise NotImplementedError( + "MMPretrain doesn't support to train" + ' the distilled version EfficientFormer.') + else: + return super().loss(feats, data_samples, **kwargs) diff --git a/mmpretrain/models/heads/grounding_head.py b/mmpretrain/models/heads/grounding_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a47512ef5930dde51a7023a07c3412d759b6bd8c --- /dev/null +++ b/mmpretrain/models/heads/grounding_head.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.models.utils.box_utils import (box_cxcywh_to_xyxy, + generalized_box_iou) +from mmpretrain.registry import MODELS, TOKENIZER + + +@MODELS.register_module() +class GroundingHead(BaseModule): + """bbox Coordination generation head for multi-modal pre-trained task, + adapted by BLIP. Normally used for visual grounding. + + Args: + loss: dict, + decoder: dict, + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__( + self, + decoder: dict = None, + tokenizer: dict = None, + box_l1_loss_coeff=4.0, + box_giou_loss_coeff=2.0, + init_cfg: Optional[dict] = None, + ) -> None: + super(GroundingHead, self).__init__(init_cfg=init_cfg) + ''' init the decoder from med_config''' + self.decoder = None + if decoder: + self.decoder = MODELS.build(decoder) + self.loss_fn = torch.nn.CrossEntropyLoss( + reduction='none', ignore_index=-100) + + self.box_l1_loss_coeff = box_l1_loss_coeff + self.box_giou_loss_coeff = box_giou_loss_coeff + + if isinstance(tokenizer, dict): + self.tokenizer = TOKENIZER.build(tokenizer) + else: + self.tokenizer = tokenizer + + self.image_res = 640 + prefix_ids = torch.tensor( + self.tokenizer.convert_tokens_to_ids(['[unused339]'])) + target_ids = torch.tensor( + self.tokenizer.convert_tokens_to_ids( + [f'[unused{340+_}]' for _ in range(self.image_res + 1)])) + self.register_buffer('prefix_ids', prefix_ids) + self.register_buffer('target_ids', target_ids) + + bbox_prob_mask = torch.zeros(len(self.tokenizer)) + bbox_prob_mask[self.target_ids[0]:self.target_ids[-1] + 1] = 1 + bbox_prob_mask = (1.0 - bbox_prob_mask) * -10000.0 + self.register_buffer('bbox_prob_mask', bbox_prob_mask) + self.bin_start_idx = self.target_ids[0] + + def forward(self, text_embedding, text_embedding_mask, + encoder_hidden_states, encoder_attention_mask): + + # localize prompt token, text embedding + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + loc_prompt = self.prompt.weight.T + loc_prompt = torch.repeat_interleave(loc_prompt, + merge_att_mask.shape[0], + 0).unsqueeze(1) + + loc_prompt_mask = torch.ones(loc_prompt.shape[:-1]).long().to( + loc_prompt.device) + + decoder_out = self.decoder( + inputs_embeds=loc_prompt, + attention_mask=loc_prompt_mask, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + output_hidden_states=True, + labels=None, + ) + decoder_hs = decoder_out.hidden_states[-1][:, 0, :] + box_pred = self.box_head(decoder_hs) + return decoder_out, decoder_hs, box_pred + + def loss(self, + text_embedding, + text_embedding_mask, + encoder_hidden_states, + encoder_attention_mask, + decoder_targets, + return_scores=False): + """Calculate losses from the extracted features. + + Args: + feats (dict): The features extracted from the backbone. + data_samples (List[BaseDataElement]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + answer_targets = (decoder_targets * + self.image_res).long() + self.bin_start_idx + prefix_ids = torch.repeat_interleave(self.prefix_ids, + merge_att_mask.shape[0], + 0).unsqueeze(-1) + prefix_ids = torch.cat([prefix_ids, answer_targets], dim=1) + + answer_output = self.decoder( + prefix_ids, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + labels=None, + return_dict=True, + ) + prob_mask = self.bbox_prob_mask.view(1, 1, + self.bbox_prob_mask.shape[-1]) + prediction_scores = answer_output.logits + prob_mask + + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = prefix_ids[:, 1:].contiguous() + vocab_size = len(self.tokenizer) + loss_seq_init = self.loss_fn( + shifted_prediction_scores.view(-1, vocab_size), labels.view(-1)) + + with torch.no_grad(): + pred_box = (torch.argmax( + prediction_scores[:, :-1, :].contiguous(), dim=-1) - + self.bin_start_idx) / self.image_res + weight_bbox = F.l1_loss( + pred_box, decoder_targets, reduction='none').clamp( + 0, 5) * self.box_l1_loss_coeff + weight_giou = (1 - torch.diag( + generalized_box_iou( + box_cxcywh_to_xyxy(pred_box), + box_cxcywh_to_xyxy(decoder_targets))) + ) * self.box_giou_loss_coeff + bs = text_embedding.shape[0] + loss_seq = loss_seq_init[:].view(bs, -1, 4) + loss_seq = loss_seq * weight_bbox + loss_seq = loss_seq * weight_giou.unsqueeze(1) + + loss_seq = loss_seq.mean() + + losses = { + 'loss_seq': loss_seq, + 'loss_seq_init': loss_seq_init.mean(), + 'loss': loss_seq, + 'box_l1': weight_bbox.mean(-1).mean().detach(), + 'box_giou': weight_giou.mean().detach() + } + + return losses + + def predict( + self, + text_embedding, + text_embedding_mask, + encoder_hidden_states, + encoder_attention_mask, + ): + """Generates the bbox coordinates at inference time.""" + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + prefix_ids = torch.repeat_interleave(self.prefix_ids, + merge_att_mask.shape[0], + 0).unsqueeze(-1) + + for _ in range(4): + decoder_output = self.decoder( + prefix_ids, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + labels=None, + return_dict=True, + ) + prob_mask = self.bbox_prob_mask.view(1, 1, + self.bbox_prob_mask.shape[-1]) + prediction_scores = decoder_output.logits + prob_mask + + prefix_ids = torch.cat([ + prefix_ids, + torch.argmax(prediction_scores[:, -1, :], dim=-1).unsqueeze(1) + ], + dim=1) + + pred_box = self.process_bbox(prefix_ids[:, 1:]) # xywh 0-1 to xyxy 0-1 + + return pred_box + + @torch.no_grad() + def process_bbox(self, bbox): + bbox = bbox - self.bin_start_idx + bbox = torch.true_divide(bbox, self.image_res) + bbox = box_cxcywh_to_xyxy(bbox) + bbox = torch.clip(bbox, 0, 1) + assert torch.all(bbox <= 1) + return bbox diff --git a/mmpretrain/models/heads/itc_head.py b/mmpretrain/models/heads/itc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..006d52c76d9317809c7bb07519f4efb18716d8bd --- /dev/null +++ b/mmpretrain/models/heads/itc_head.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.dist import all_gather +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class ITCHead(BaseModule): + """Image-text matching head for multi-modal pre-trained task. Adapted by + BLIP, ALBEF. Normally used for retrieval task. + + Args: + embed_dim (int): Embed channel size for queue. + queue_size (int): Queue size for image and text. Defaults to 57600. + temperature (float): Temperature to calculate the similarity. + Defaults to 0.07. + use_distill (bool): Whether to use distill to calculate loss. + Defaults to True. + alpha (float): Weight for momentum similarity. Defaults to 0.4. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + embed_dim: int, + queue_size: int = 57600, + temperature: float = 0.07, + use_distill: bool = True, + alpha: float = 0.4, + init_cfg: Optional[dict] = None): + super(ITCHead, self).__init__(init_cfg=init_cfg) + self.temp = nn.Parameter(temperature * torch.ones([])) + self.use_distill = use_distill + if self.use_distill: + # create the queue + self.register_buffer('image_queue', + torch.randn(embed_dim, queue_size)) + self.register_buffer('text_queue', + torch.randn(embed_dim, queue_size)) + self.register_buffer('idx_queue', torch.full((1, queue_size), + -100)) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + self.image_queue = F.normalize(self.image_queue, dim=0) + self.text_queue = F.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + # This value will be warmup by `WarmupParamHook` + self.alpha = alpha + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + return feats[-1] + + def loss(self, feats: Tuple[torch.Tensor], data_samples, **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[ClsDataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + # The part can be traced by torch.fx + img_feats, text_feats, img_feats_m, text_feats_m = self(feats) + + img_feats_all = torch.cat( + [img_feats_m.t(), + self.image_queue.clone().detach()], dim=1) + text_feats_all = torch.cat( + [text_feats_m.t(), + self.text_queue.clone().detach()], dim=1) + + # The part can not be traced by torch.fx + losses = self._get_loss(img_feats, text_feats, img_feats_m, + text_feats_m, img_feats_all, text_feats_all, + data_samples, **kwargs) + return losses + + def _get_loss(self, img_feats, text_feats, img_feats_m, text_feats_m, + img_feats_all, text_feats_all, data_samples, **kwargs): + """Unpack data samples and compute loss.""" + + idx = torch.tensor([ds.image_id + for ds in data_samples]).to(img_feats.device) + idx = idx.view(-1, 1) + idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) + + with torch.no_grad(): + if self.use_distill: + sim_i2t_m = img_feats_m @ text_feats_all / self.temp + sim_t2i_m = text_feats_m @ img_feats_all / self.temp + + sim_i2t_targets = ( + self.alpha * F.softmax(sim_i2t_m, dim=1) + + (1 - self.alpha) * sim_targets) + sim_t2i_targets = ( + self.alpha * F.softmax(sim_t2i_m, dim=1) + + (1 - self.alpha) * sim_targets) + + sim_i2t = img_feats @ text_feats_all / self.temp + sim_t2i = text_feats @ img_feats_all / self.temp + + if self.use_distill: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean() + else: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean() + + # compute loss + losses = dict() + + losses['itc_loss'] = (loss_i2t + loss_t2i) / 2 + self._dequeue_and_enqueue(img_feats_m, text_feats_m, idx) + return losses + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): + # gather keys before updating queue + image_feats = torch.cat(all_gather(image_feat)) + text_feats = torch.cat(all_gather(text_feat)) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + + if idxs is not None: + idxs = torch.cat(all_gather(idxs)) + self.idx_queue[:, ptr:ptr + batch_size] = idxs.T + + ptr = (ptr + batch_size) % self.queue_size # move pointer + self.queue_ptr[0] = ptr diff --git a/mmpretrain/models/heads/itm_head.py b/mmpretrain/models/heads/itm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b42f3f684e2ffefd085b39360706a339017f4c --- /dev/null +++ b/mmpretrain/models/heads/itm_head.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.evaluation import Accuracy +from mmpretrain.registry import MODELS + + +class Pooler(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@MODELS.register_module() +class ITMHead(BaseModule): + """Image-text matching head for multi-modal pre-trained task. Adapted by + BLIP, FLAVA. + + Args: + hidden_size (int): Hidden channel size out input features. + with_pooler (bool): Whether a pooler is added. Defaults to True. + loss (dict): Config of global contrasive loss. Defaults to + ``dict(type='GlobalContrasiveLoss')``. + cal_acc (bool): Whether to calculate accuracy during training. + If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + hidden_size: int, + with_pooler: bool = True, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + cal_acc: bool = False, + init_cfg: Optional[dict] = None): + super(ITMHead, self).__init__(init_cfg=init_cfg) + self.hidden_size = hidden_size + + if with_pooler: + self.pooler = Pooler(hidden_size=self.hidden_size) + else: + self.pooler = nn.Identity() + self.fc = nn.Linear(self.hidden_size, 2) + + self.loss_module = MODELS.build(loss) + self.cal_acc = cal_acc + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pooler(feats[-1]) + itm_logits = self.fc(pre_logits) + return itm_logits + + def loss(self, feats: Tuple[torch.Tensor], data_samples, **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[ClsDataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + # The part can be traced by torch.fx + itm_logits = self(feats) + + # deal with query + if itm_logits.ndim == 3: + itm_logits = itm_logits.mean(dim=1) + + # The part can not be traced by torch.fx + losses = self._get_loss(itm_logits, data_samples, **kwargs) + return losses + + def _get_loss(self, itm_logits: torch.Tensor, data_samples, **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + # use `itm_label` in here temporarily + target = torch.tensor([i.is_matched + for i in data_samples]).to(itm_logits.device) + + # compute loss + losses = dict() + + loss = self.loss_module( + itm_logits, target.long(), avg_factor=itm_logits.size(0), **kwargs) + losses['itm_loss'] = loss + + # compute accuracy + if self.cal_acc: + # topk is meaningless for matching task + acc = Accuracy.calculate(itm_logits, target) + # acc is warpped with two lists of topk and thrs + # which are unnecessary here + losses.update({'itm_accuracy': acc[0][0]}) + + return losses diff --git a/mmpretrain/models/heads/itpn_clip_head.py b/mmpretrain/models/heads/itpn_clip_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7465d7c2a8924d93afc4e9f5a461bcea49880aee --- /dev/null +++ b/mmpretrain/models/heads/itpn_clip_head.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class iTPNClipHead(BaseModule): + """Head for iTPN Pre-training using Clip. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss_module = MODELS.build(loss) + + def loss(self, feats: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraing. + """ + + mask = mask.to(torch.device('cuda'), non_blocking=True) + mask = mask.flatten(1).to(torch.bool) + target = target[mask] + + # remove cls_token + # feats = feats[:, 1:] + logits = self.cls_head(feats[mask]) + + loss = self.loss_module(logits, target) + return loss diff --git a/mmpretrain/models/heads/latent_heads.py b/mmpretrain/models/heads/latent_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..a9662b5d91c8534d1a2a7834e4b9e3ec37f552c1 --- /dev/null +++ b/mmpretrain/models/heads/latent_heads.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_reduce, get_world_size +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class LatentPredictHead(BaseModule): + """Head for latent feature prediction. + + This head builds a predictor, which can be any registered neck component. + For example, BYOL and SimSiam call this head and build NonLinearNeck. + It also implements similarity loss between two forward features. + + Args: + loss (dict): Config dict for the loss. + predictor (dict): Config dict for the predictor. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + predictor: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + self.predictor = MODELS.build(predictor) + + def loss(self, input: torch.Tensor, + target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward head. + + Args: + input (torch.Tensor): NxC input features. + target (torch.Tensor): NxC target features. + + Returns: + torch.Tensor: The latent predict loss. + """ + pred = self.predictor([input])[0] + target = target.detach() + + loss = self.loss_module(pred, target) + + return loss + + +@MODELS.register_module() +class LatentCrossCorrelationHead(BaseModule): + """Head for latent feature cross correlation. + + Part of the code is borrowed from `script + `_. + + Args: + in_channels (int): Number of input channels. + loss (dict): Config dict for module of loss functions. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.world_size = get_world_size() + self.bn = nn.BatchNorm1d(in_channels, affine=False) + self.loss_module = MODELS.build(loss) + + def loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Forward head. + + Args: + input (torch.Tensor): NxC input features. + target (torch.Tensor): NxC target features. + + Returns: + torch.Tensor: The cross correlation loss. + """ + # cross-correlation matrix + cross_correlation_matrix = self.bn(input).T @ self.bn(target) + cross_correlation_matrix.div_(input.size(0) * self.world_size) + + all_reduce(cross_correlation_matrix) + + loss = self.loss_module(cross_correlation_matrix) + return loss diff --git a/mmpretrain/models/heads/levit_head.py b/mmpretrain/models/heads/levit_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a74d7ecc52caca0adca642e528f2861f9a0e5833 --- /dev/null +++ b/mmpretrain/models/heads/levit_head.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.models.heads import ClsHead +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +class BatchNormLinear(BaseModule): + + def __init__(self, in_channels, out_channels, norm_cfg=dict(type='BN1d')): + super(BatchNormLinear, self).__init__() + self.bn = build_norm_layer(norm_cfg, in_channels) + self.linear = nn.Linear(in_channels, out_channels) + + @torch.no_grad() + def fuse(self): + w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5 + b = self.bn.bias - self.bn.running_mean * \ + self.bn.weight / (self.bn.running_var + self.bn.eps) ** 0.5 + w = self.linear.weight * w[None, :] + b = (self.linear.weight @ b[:, None]).view(-1) + self.linear.bias + + self.linear.weight.data.copy_(w) + self.linear.bias.data.copy_(b) + return self.linear + + def forward(self, x): + x = self.bn(x) + x = self.linear(x) + return x + + +def fuse_parameters(module): + for child_name, child in module.named_children(): + if hasattr(child, 'fuse'): + setattr(module, child_name, child.fuse()) + else: + fuse_parameters(child) + + +@MODELS.register_module() +class LeViTClsHead(ClsHead): + + def __init__(self, + num_classes=1000, + distillation=True, + in_channels=None, + deploy=False, + **kwargs): + super(LeViTClsHead, self).__init__(**kwargs) + self.num_classes = num_classes + self.distillation = distillation + self.deploy = deploy + self.head = BatchNormLinear(in_channels, num_classes) + if distillation: + self.head_dist = BatchNormLinear(in_channels, num_classes) + + if self.deploy: + self.switch_to_deploy(self) + + def switch_to_deploy(self): + if self.deploy: + return + fuse_parameters(self) + self.deploy = True + + def forward(self, x): + x = self.pre_logits(x) + if self.distillation: + x = self.head(x), self.head_dist(x) # 2 16 384 -> 2 1000 + if not self.training: + x = (x[0] + x[1]) / 2 + else: + raise NotImplementedError("MMPretrain doesn't support " + 'training in distillation mode.') + else: + x = self.head(x) + return x diff --git a/mmpretrain/models/heads/linear_head.py b/mmpretrain/models/heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..90b4c2b11eb0b2ba087fd438a32596cedb13cebb --- /dev/null +++ b/mmpretrain/models/heads/linear_head.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class LinearClsHead(ClsHead): + """Linear classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. + cal_acc (bool): Whether to calculate accuracy during training. + If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + init_cfg: Optional[dict] = dict( + type='Normal', layer='Linear', std=0.01), + **kwargs): + super(LinearClsHead, self).__init__(init_cfg=init_cfg, **kwargs) + + self.in_channels = in_channels + self.num_classes = num_classes + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.fc = nn.Linear(self.in_channels, self.num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``LinearClsHead``, we just obtain the + feature of the last stage. + """ + # The LinearClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/mae_head.py b/mmpretrain/models/heads/mae_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5366d13b5f5bed0baedea06b9ff956ff5cf16b --- /dev/null +++ b/mmpretrain/models/heads/mae_head.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MAEPretrainHead(BaseModule): + """Head for MAE Pre-training. + + Args: + loss (dict): Config of loss. + norm_pix_loss (bool): Whether or not normalize target. + Defaults to False. + patch_size (int): Patch size. Defaults to 16. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = False, + patch_size: int = 16) -> None: + super().__init__() + self.norm_pix = norm_pix + self.patch_size = patch_size + self.loss_module = MODELS.build(loss) + + def patchify(self, imgs: torch.Tensor) -> torch.Tensor: + r"""Split images into non-overlapped patches. + + Args: + imgs (torch.Tensor): A batch of images. The shape should + be :math:`(B, 3, H, W)`. + + Returns: + torch.Tensor: Patchified images. The shape is + :math:`(B, L, \text{patch_size}^2 \times 3)`. + """ + p = self.patch_size + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + return x + + def unpatchify(self, x: torch.Tensor) -> torch.Tensor: + r"""Combine non-overlapped patches into images. + + Args: + x (torch.Tensor): The shape is + :math:`(B, L, \text{patch_size}^2 \times 3)`. + + Returns: + torch.Tensor: The shape is :math:`(B, 3, H, W)`. + """ + p = self.patch_size + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + """Construct the reconstruction target. + + In addition to splitting images into tokens, this module will also + normalize the image according to ``norm_pix``. + + Args: + target (torch.Tensor): Image with the shape of B x 3 x H x W + + Returns: + torch.Tensor: Tokenized images with the shape of B x L x C + """ + target = self.patchify(target) + if self.norm_pix: + # normalize the target image + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + return target + + def loss(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + target = self.construct_target(target) + loss = self.loss_module(pred, target, mask) + + return loss diff --git a/mmpretrain/models/heads/margin_head.py b/mmpretrain/models/heads/margin_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3a88bf8b3f4d19b233192a7578f49b750ff53ed5 --- /dev/null +++ b/mmpretrain/models/heads/margin_head.py @@ -0,0 +1,300 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.fileio import list_from_file +from mmengine.runner import autocast +from mmengine.utils import is_seq_of + +from mmpretrain.models.losses import convert_to_one_hot +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +class NormProduct(nn.Linear): + """An enhanced linear layer with k clustering centers to calculate product + between normalized input and linear weight. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample + k (int): The number of clustering centers. Defaults to 1. + bias (bool): Whether there is bias. If set to ``False``, the + layer will not learn an additive bias. Defaults to ``True``. + feature_norm (bool): Whether to normalize the input feature. + Defaults to ``True``. + weight_norm (bool):Whether to normalize the weight. + Defaults to ``True``. + """ + + def __init__(self, + in_features: int, + out_features: int, + k=1, + bias: bool = False, + feature_norm: bool = True, + weight_norm: bool = True): + + super().__init__(in_features, out_features * k, bias=bias) + self.weight_norm = weight_norm + self.feature_norm = feature_norm + self.out_features = out_features + self.k = k + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.feature_norm: + input = F.normalize(input) + if self.weight_norm: + weight = F.normalize(self.weight) + else: + weight = self.weight + cosine_all = F.linear(input, weight, self.bias) + + if self.k == 1: + return cosine_all + else: + cosine_all = cosine_all.view(-1, self.out_features, self.k) + cosine, _ = torch.max(cosine_all, dim=2) + return cosine + + +@MODELS.register_module() +class ArcFaceClsHead(ClsHead): + """ArcFace classifier head. + + A PyTorch implementation of paper `ArcFace: Additive Angular Margin Loss + for Deep Face Recognition `_ and + `Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web + Faces `_ + + Example: + To use ArcFace in config files. + + 1. use vanilla ArcFace + + .. code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + 2. use SubCenterArcFace with 3 sub-centers + + .. code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + num_subcenters=3, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + 3. use SubCenterArcFace With CountPowerAdaptiveMargins + + .. code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + num_subcenters=3, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + custom_hooks = [dict(type='SetAdaptiveMarginsHook')] + + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + num_subcenters (int): Number of subcenters. Defaults to 1. + scale (float): Scale factor of output logit. Defaults to 64.0. + margins (float): The penalty margin. Could be the fllowing formats: + + - float: The margin, would be same for all the categories. + - Sequence[float]: The category-based margins list. + - str: A '.txt' file path which contains a list. Each line + represents the margin of a category, and the number in the + i-th row indicates the margin of the i-th class. + + Defaults to 0.5. + easy_margin (bool): Avoid theta + m >= PI. Defaults to False. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + num_subcenters: int = 1, + scale: float = 64., + margins: Optional[Union[float, Sequence[float], str]] = 0.50, + easy_margin: bool = False, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg: Optional[dict] = None): + + super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg) + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + assert num_subcenters >= 1 and num_classes >= 0 + self.in_channels = in_channels + self.num_classes = num_classes + self.num_subcenters = num_subcenters + self.scale = scale + self.easy_margin = easy_margin + + self.norm_product = NormProduct(in_channels, num_classes, + num_subcenters) + + if isinstance(margins, float): + margins = [margins] * num_classes + elif isinstance(margins, str) and margins.endswith('.txt'): + margins = [float(item) for item in list_from_file(margins)] + else: + assert is_seq_of(list(margins), (float, int)), ( + 'the attribute `margins` in ``ArcFaceClsHead`` should be a ' + ' float, a Sequence of float, or a ".txt" file path.') + + assert len(margins) == num_classes, \ + 'The length of margins must be equal with num_classes.' + + self.register_buffer( + 'margins', torch.tensor(margins).float(), persistent=False) + # To make `phi` monotonic decreasing, refers to + # https://github.com/deepinsight/insightface/issues/108 + sinm_m = torch.sin(math.pi - self.margins) * self.margins + threshold = torch.cos(math.pi - self.margins) + self.register_buffer('sinm_m', sinm_m, persistent=False) + self.register_buffer('threshold', threshold, persistent=False) + + def set_margins(self, margins: Union[Sequence[float], float]) -> None: + """set margins of arcface head. + + Args: + margins (Union[Sequence[float], float]): The marigins. + """ + if isinstance(margins, float): + margins = [margins] * self.num_classes + assert is_seq_of( + list(margins), float) and (len(margins) == self.num_classes), ( + f'margins must be Sequence[Union(float, int)], get {margins}') + + self.margins = torch.tensor( + margins, device=self.margins.device, dtype=torch.float32) + self.sinm_m = torch.sin(self.margins) * self.margins + self.threshold = -torch.cos(self.margins) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ArcFaceHead``, we just obtain the + feature of the last stage. + """ + # The ArcFaceHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def _get_logit_with_margin(self, pre_logits, target): + """add arc margin to the cosine in target index. + + The target must be in index format. + """ + assert target.dim() == 1 or ( + target.dim() == 2 and target.shape[1] == 1), \ + 'The target must be in index format.' + cosine = self.norm_product(pre_logits) + phi = torch.cos(torch.acos(cosine) + self.margins) + + if self.easy_margin: + # when cosine>0, choose phi + # when cosine<=0, choose cosine + phi = torch.where(cosine > 0, phi, cosine) + else: + # when cos>th, choose phi + # when cos<=th, choose cosine-mm + phi = torch.where(cosine > self.threshold, phi, + cosine - self.sinm_m) + + target = convert_to_one_hot(target, self.num_classes) + output = target * phi + (1 - target) * cosine + return output + + def forward(self, + feats: Tuple[torch.Tensor], + target: Optional[torch.Tensor] = None) -> torch.Tensor: + """The forward process.""" + # Disable AMP + with autocast(enabled=False): + pre_logits = self.pre_logits(feats) + + if target is None: + # when eval, logit is the cosine between W and pre_logits; + # cos(theta_yj) = (x/||x||) * (W/||W||) + logit = self.norm_product(pre_logits) + else: + # when training, add a margin to the pre_logits where target is + # True, then logit is the cosine between W and new pre_logits + logit = self._get_logit_with_margin(pre_logits, target) + + return self.scale * logit + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # Unpack data samples and pack targets + label_target = torch.cat([i.gt_label for i in data_samples]) + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = label_target + + # the index format target would be used + cls_score = self(feats, label_target) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses diff --git a/mmpretrain/models/heads/mim_head.py b/mmpretrain/models/heads/mim_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bda90c8198986ec9b2ff2d03db3350e1f1a25823 --- /dev/null +++ b/mmpretrain/models/heads/mim_head.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MIMHead(BaseModule): + """Pre-training head for Masked Image Modeling. + + Args: + loss (dict): Config dict for module of loss functions. + """ + + def __init__(self, loss: dict) -> None: + super().__init__() + self.loss_module = MODELS.build(loss) + + def loss(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward head. + + Args: + pred (torch.Tensor): Predictions with shape B x L x C. + target (torch.Tensor): Targets with shape B x L x C. + mask (torch.Tensor): Mask with shape B x L. + + Returns: + torch.Tensor: The loss tensor. + """ + loss = self.loss_module(pred, target, mask) + return loss diff --git a/mmpretrain/models/heads/mixmim_head.py b/mmpretrain/models/heads/mixmim_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a709630abb26bce1153596cec842da0912bab912 --- /dev/null +++ b/mmpretrain/models/heads/mixmim_head.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmpretrain.registry import MODELS +from .mae_head import MAEPretrainHead + + +@MODELS.register_module() +class MixMIMPretrainHead(MAEPretrainHead): + """Head for MixMIM Pre-training. + + Args: + loss (dict): Config of loss. + norm_pix_loss (bool): Whether or not normalize target. + Defaults to False. + patch_size (int): Patch size. Defaults to 16. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = False, + patch_size: int = 16) -> None: + super().__init__(loss=loss, norm_pix=norm_pix, patch_size=patch_size) + + def loss(self, x_rec: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + target = self.construct_target(target) + + B, L, C = x_rec.shape + + # unmix tokens + x1_rec = x_rec[:B // 2] + x2_rec = x_rec[B // 2:] + + unmix_x_rec = x1_rec * mask + x2_rec.flip(0) * (1 - mask) + + loss_rec = self.loss_module(unmix_x_rec, target) + + return loss_rec diff --git a/mmpretrain/models/heads/mocov3_head.py b/mmpretrain/models/heads/mocov3_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c2bec2a6cc90247fab44d6d954a8a0c6ede0a812 --- /dev/null +++ b/mmpretrain/models/heads/mocov3_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.dist import all_gather, get_rank +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MoCoV3Head(BaseModule): + """Head for MoCo v3 Pre-training. + + This head builds a predictor, which can be any registered neck component. + It also implements latent contrastive loss between two forward features. + Part of the code is modified from: + ``_. + + Args: + predictor (dict): Config dict for module of predictor. + loss (dict): Config dict for module of loss functions. + temperature (float): The temperature hyper-parameter that + controls the concentration level of the distribution. + Defaults to 1.0. + """ + + def __init__(self, + predictor: dict, + loss: dict, + temperature: float = 1.0) -> None: + super().__init__() + self.predictor = MODELS.build(predictor) + self.loss_module = MODELS.build(loss) + self.temperature = temperature + + def loss(self, base_out: torch.Tensor, + momentum_out: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + base_out (torch.Tensor): NxC features from base_encoder. + momentum_out (torch.Tensor): NxC features from momentum_encoder. + + Returns: + torch.Tensor: The loss tensor. + """ + # predictor computation + pred = self.predictor([base_out])[0] + + # normalize + pred = nn.functional.normalize(pred, dim=1) + target = nn.functional.normalize(momentum_out, dim=1) + + # get negative samples + target = torch.cat(all_gather(target), dim=0) + + # Einstein sum is more intuitive + logits = torch.einsum('nc,mc->nm', [pred, target]) / self.temperature + + # generate labels + batch_size = logits.shape[0] + labels = (torch.arange(batch_size, dtype=torch.long) + + batch_size * get_rank()).to(logits.device) + + loss = self.loss_module(logits, labels) + return loss diff --git a/mmpretrain/models/heads/multi_label_cls_head.py b/mmpretrain/models/heads/multi_label_cls_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ca36bfe06e70e1e0f16a5dc4c161b186234f57ac --- /dev/null +++ b/mmpretrain/models/heads/multi_label_cls_head.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample, label_to_onehot + + +@MODELS.register_module() +class MultiLabelClsHead(BaseModule): + """Classification head for multilabel task. + + Args: + loss (dict): Config of classification loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True). + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + init_cfg (dict, optional): The extra init config of layers. + Defaults to None. + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + + def __init__(self, + loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True), + thr: Optional[float] = None, + topk: Optional[int] = None, + init_cfg: Optional[dict] = None): + super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + if thr is None and topk is None: + thr = 0.5 + + self.thr = thr + self.topk = topk + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``MultiLabelClsHead``, we just obtain + the feature of the last stage. + """ + # The MultiLabelClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The MultiLabelClsHead doesn't have the final classification head, + # just return the unpacked inputs. + return pre_logits + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + num_classes = cls_score.size()[-1] + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + target = torch.stack([i.gt_score.float() for i in data_samples]) + else: + target = torch.stack([ + label_to_onehot(i.gt_label, num_classes) for i in data_samples + ]).float() + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + feats: Tuple[torch.Tensor], + data_samples: List[DataSample] = None) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score: torch.Tensor, + data_samples: List[DataSample]): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = torch.sigmoid(cls_score) + + if data_samples is None: + data_samples = [DataSample() for _ in range(cls_score.size(0))] + + for data_sample, score in zip(data_samples, pred_scores): + if self.thr is not None: + # a label is predicted positive if larger than thr + label = torch.where(score >= self.thr)[0] + else: + # top-k labels will be predicted positive for any example + _, label = score.topk(self.topk) + data_sample.set_pred_score(score).set_pred_label(label) + + return data_samples diff --git a/mmpretrain/models/heads/multi_label_csra_head.py b/mmpretrain/models/heads/multi_label_csra_head.py new file mode 100644 index 0000000000000000000000000000000000000000..95a3a0e8b9d6c68c2f2c1da3c0c160c4c695cc7c --- /dev/null +++ b/mmpretrain/models/heads/multi_label_csra_head.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/Kevinz-code/CSRA +from typing import Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from .multi_label_cls_head import MultiLabelClsHead + + +@MODELS.register_module() +class CSRAClsHead(MultiLabelClsHead): + """Class-specific residual attention classifier head. + + Please refer to the `Residual Attention: A Simple but Effective Method for + Multi-Label Recognition (ICCV 2021) `_ + for details. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + num_heads (int): Number of residual at tensor heads. + loss (dict): Config of classification loss. + lam (float): Lambda that combines global average and max pooling + scores. + init_cfg (dict, optional): The extra init config of layers. + Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + temperature_settings = { # softmax temperature settings + 1: [1], + 2: [1, 99], + 4: [1, 2, 4, 99], + 6: [1, 2, 3, 4, 5, 99], + 8: [1, 2, 3, 4, 5, 6, 7, 99] + } + + def __init__(self, + num_classes: int, + in_channels: int, + num_heads: int, + lam: float, + init_cfg=dict(type='Normal', layer='Linear', std=0.01), + **kwargs): + assert num_heads in self.temperature_settings.keys( + ), 'The num of heads is not in temperature setting.' + assert lam > 0, 'Lambda should be between 0 and 1.' + super(CSRAClsHead, self).__init__(init_cfg=init_cfg, **kwargs) + self.temp_list = self.temperature_settings[num_heads] + self.csra_heads = ModuleList([ + CSRAModule(num_classes, in_channels, self.temp_list[i], lam) + for i in range(num_heads) + ]) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``CSRAClsHead``, we just obtain the + feature of the last stage. + """ + # The CSRAClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + logit = sum([head(pre_logits) for head in self.csra_heads]) + return logit + + +class CSRAModule(BaseModule): + """Basic module of CSRA with different temperature. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + T (int): Temperature setting. + lam (float): Lambda that combines global average and max pooling + scores. + init_cfg (dict | optional): The extra init config of layers. + Defaults to use dict(type='Normal', layer='Linear', std=0.01). + """ + + def __init__(self, + num_classes: int, + in_channels: int, + T: int, + lam: float, + init_cfg=None): + + super(CSRAModule, self).__init__(init_cfg=init_cfg) + self.T = T # temperature + self.lam = lam # Lambda + self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False) + self.softmax = nn.Softmax(dim=2) + + def forward(self, x): + score = self.head(x) / torch.norm( + self.head.weight, dim=1, keepdim=True).transpose(0, 1) + score = score.flatten(2) + base_logit = torch.mean(score, dim=2) + + if self.T == 99: # max-pooling + att_logit = torch.max(score, dim=2)[0] + else: + score_soft = self.softmax(score * self.T) + att_logit = torch.sum(score * score_soft, dim=2) + + return base_logit + self.lam * att_logit diff --git a/mmpretrain/models/heads/multi_label_linear_head.py b/mmpretrain/models/heads/multi_label_linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..81217ec55c54f23748b7e4ce8797509abfbb2ed3 --- /dev/null +++ b/mmpretrain/models/heads/multi_label_linear_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .multi_label_cls_head import MultiLabelClsHead + + +@MODELS.register_module() +class MultiLabelLinearClsHead(MultiLabelClsHead): + """Linear classification head for multilabel task. + + Args: + loss (dict): Config of classification loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True). + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + init_cfg (dict, optional): The extra init config of layers. + Defaults to use dict(type='Normal', layer='Linear', std=0.01). + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True), + thr: Optional[float] = None, + topk: Optional[int] = None, + init_cfg: Optional[dict] = dict( + type='Normal', layer='Linear', std=0.01)): + super(MultiLabelLinearClsHead, self).__init__( + loss=loss, thr=thr, topk=topk, init_cfg=init_cfg) + + assert num_classes > 0, f'num_classes ({num_classes}) must be a ' \ + 'positive integer.' + + self.in_channels = in_channels + self.num_classes = num_classes + + self.fc = nn.Linear(self.in_channels, self.num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``MultiLabelLinearClsHead``, we just + obtain the feature of the last stage. + """ + # The obtain the MultiLabelLinearClsHead doesn't have other module, + # just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/multi_task_head.py b/mmpretrain/models/heads/multi_task_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4645a790d8494a216d945c91496388e0629c79 --- /dev/null +++ b/mmpretrain/models/heads/multi_task_head.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule, ModuleDict + +from mmpretrain.registry import MODELS +from mmpretrain.structures import MultiTaskDataSample + + +def loss_convertor(loss_func, task_name): + + def wrapped(inputs, data_samples, **kwargs): + mask = torch.empty(len(data_samples), dtype=torch.bool) + task_data_samples = [] + for i, data_sample in enumerate(data_samples): + assert isinstance(data_sample, MultiTaskDataSample) + sample_mask = task_name in data_sample + mask[i] = sample_mask + if sample_mask: + task_data_samples.append(data_sample.get(task_name)) + + if len(task_data_samples) == 0: + # This makes it possible to perform loss.backward when a + # task does not have gt_labels within a batch. + loss = (inputs[0] * 0).sum() + return {'loss': loss, 'mask_size': torch.tensor(0.)} + + # Mask the inputs of the task + def mask_inputs(inputs, mask): + if isinstance(inputs, Sequence): + return type(inputs)( + [mask_inputs(input, mask) for input in inputs]) + elif isinstance(inputs, torch.Tensor): + return inputs[mask] + + masked_inputs = mask_inputs(inputs, mask) + loss_output = loss_func(masked_inputs, task_data_samples, **kwargs) + loss_output['mask_size'] = mask.sum().to(torch.float) + return loss_output + + return wrapped + + +@MODELS.register_module() +class MultiTaskHead(BaseModule): + """Multi task head. + + Args: + task_heads (dict): Sub heads to use, the key will be use to rename the + loss components. + common_cfg (dict): The common settings for all heads. Defaults to an + empty dict. + init_cfg (dict, optional): The extra initialization settings. + Defaults to None. + """ + + def __init__(self, task_heads, init_cfg=None, **kwargs): + super(MultiTaskHead, self).__init__(init_cfg=init_cfg) + + assert isinstance(task_heads, dict), 'The `task_heads` argument' \ + "should be a dict, which's keys are task names and values are" \ + 'configs of head for the task.' + + self.task_heads = ModuleDict() + + for task_name, sub_head in task_heads.items(): + if not isinstance(sub_head, nn.Module): + sub_head = MODELS.build(sub_head, default_args=kwargs) + sub_head.loss = loss_convertor(sub_head.loss, task_name) + self.task_heads[task_name] = sub_head + + def forward(self, feats): + """The forward process.""" + return { + task_name: head(feats) + for task_name, head in self.task_heads.items() + } + + def loss(self, feats: Tuple[torch.Tensor], + data_samples: List[MultiTaskDataSample], **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + data_samples (List[MultiTaskDataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components, each task loss + key will be prefixed by the task_name like "task1_loss" + """ + losses = dict() + for task_name, head in self.task_heads.items(): + head_loss = head.loss(feats, data_samples, **kwargs) + for k, v in head_loss.items(): + losses[f'{task_name}_{k}'] = v + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: List[MultiTaskDataSample] = None + ) -> List[MultiTaskDataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + data_samples (List[MultiTaskDataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[MultiTaskDataSample]: A list of data samples which contains + the predicted results. + """ + predictions_dict = dict() + + for task_name, head in self.task_heads.items(): + task_samples = head.predict(feats) + batch_size = len(task_samples) + predictions_dict[task_name] = task_samples + + if data_samples is None: + data_samples = [MultiTaskDataSample() for _ in range(batch_size)] + + for task_name, task_samples in predictions_dict.items(): + for data_sample, task_sample in zip(data_samples, task_samples): + task_sample.set_field( + task_name in data_sample.tasks, + 'eval_mask', + field_type='metainfo') + + if task_name in data_sample.tasks: + data_sample.get(task_name).update(task_sample) + else: + data_sample.set_field(task_sample, task_name) + + return data_samples diff --git a/mmpretrain/models/heads/seq_gen_head.py b/mmpretrain/models/heads/seq_gen_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e9b10efe6e1e6a709cd870f0572f14bbd176ee --- /dev/null +++ b/mmpretrain/models/heads/seq_gen_head.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SeqGenerationHead(BaseModule): + """Generation head for multi-modal pre-trained task, adopted by BLIP. + Normally used for generation task. + + Args: + decoder (dict): Decoder for blip generation head. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__( + self, + decoder: dict, + ignore_index=-100, + loss: dict = dict(type='LabelSmoothLoss', label_smooth_val=0.1), + init_cfg: Optional[dict] = None, + ) -> None: + super(SeqGenerationHead, self).__init__(init_cfg=init_cfg) + self.decoder = MODELS.build(decoder) + self.loss_fn = MODELS.build(loss) + self.ignore_index = ignore_index + + def forward(self, input_ids: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, labels: torch.Tensor): + """Forward to get decoder output. + + Args: + input_ids (torch.Tensor): The tokenized input text tensor. + encoder_hidden_states (torch.Tensor): Hidden states from image + embeddings. + encoder_attention_mask (torch.Tensor): Image embeddings hidden + states attention mask. + labels (torch.Tensor): Decoder target for calculate loss. + + Returns: + dict[str, Tensor]: a dictionary of decoder outputs. + """ + + decoder_out = self.decoder( + input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + labels=labels, + return_dict=True, + ) + return decoder_out + + def loss(self, input_ids, encoder_hidden_states, encoder_attention_mask, + labels): + """Calculate losses from the extracted features. + + Args: + input_ids (torch.Tensor): The tokenized input text tensor. + encoder_hidden_states (torch.Tensor): Hidden states from image + embeddings. + encoder_attention_mask (torch.Tensor): Image embeddings hidden + states attention mask. + labels (torch.Tensor): Decoder target for calculate loss. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + + decoder_out = self( + input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + labels=labels, + ) + prediction_scores = decoder_out['logits'] + # we are doing next-token prediction; + # shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + + vocab_size = prediction_scores.shape[-1] + + # mask ignored index + if (labels == self.ignore_index).any(): + labels = labels.view(-1).clone() + ignore_mask = (labels == self.ignore_index) + labels.masked_fill_(ignore_mask, 0) + weight = torch.logical_not(ignore_mask) + avg_factor = max(weight.sum(), 1) + else: + weight = None + avg_factor = labels.size(0) + + lm_loss = self.loss_fn( + shifted_prediction_scores.view(-1, vocab_size), + labels, + weight=weight, + avg_factor=avg_factor, + ) + losses = { + 'seq_gen_lm_loss': lm_loss, + } + + return losses + + def predict(self, + input_ids, + encoder_hidden_states, + sep_token_id, + pad_token_id, + use_nucleus_sampling=False, + num_beams=3, + max_length=20, + min_length=2, + top_p=0.9, + repetition_penalty=1.0, + **kwargs): + """Decoder prediction method. + + Args: + input_ids (torch.Tensor): The tokenized input text tensor. + encoder_hidden_states (torch.Tensor): Hidden states from image + embeddings. + sep_token_id (int): Tokenid of separation token. + pad_token_id (int): Tokenid of pad token. + use_nucleus_sampling (bool): Whether to use nucleus sampling in + prediction. Defaults to False. + num_beams (int): Number of beams used in predition. + Defaults to 3. + max_length (int): Max length of generated text in predition. + Defaults to 20. + min_length (int): Min length of generated text in predition. + Defaults to 20. + top_p (float): + If < 1.0, only keep the top tokens with cumulative probability + >= top_p (nucleus filtering). Defaults to 0.9. + repetition_penalty (float): The parameter for repetition penalty. + Defaults to 1.0. + **kwarg: Other arguments that might used in generation. + + Returns: + dict[str, Tensor]: a dictionary of generation outputs. + """ + device = encoder_hidden_states.device + + # TODO: In old version of transformers + # Additional repeat interleave of hidden states should be add here. + image_atts = torch.ones( + encoder_hidden_states.size()[:-1], dtype=torch.long).to(device) + + model_kwargs = { + 'encoder_hidden_states': encoder_hidden_states, + 'encoder_attention_mask': image_atts, + } + model_kwargs.update(kwargs) + + if use_nucleus_sampling: + # nucleus sampling + outputs = self.decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + # beam search + outputs = self.decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + return outputs diff --git a/mmpretrain/models/heads/simmim_head.py b/mmpretrain/models/heads/simmim_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b7af984c9eb4891e9f4281daf630355cafbb6cc7 --- /dev/null +++ b/mmpretrain/models/heads/simmim_head.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SimMIMHead(BaseModule): + """Head for SimMIM Pre-training. + + Args: + patch_size (int): Patch size of each token. + loss (dict): The config for loss. + """ + + def __init__(self, patch_size: int, loss: dict) -> None: + super().__init__() + self.patch_size = patch_size + self.loss_module = MODELS.build(loss) + + def loss(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + This method will expand mask to the size of the original image. + + Args: + pred (torch.Tensor): The reconstructed image (B, C, H, W). + target (torch.Tensor): The target image (B, C, H, W). + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave( + self.patch_size, 2).unsqueeze(1).contiguous() + loss = self.loss_module(pred, target, mask) + + return loss diff --git a/mmpretrain/models/heads/spark_head.py b/mmpretrain/models/heads/spark_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a2748762ae50e1e085bd2ce536e95c6d52e51d9c --- /dev/null +++ b/mmpretrain/models/heads/spark_head.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SparKPretrainHead(BaseModule): + """Pre-training head for SparK. + + Args: + loss (dict): Config of loss. + norm_pix (bool): Whether or not normalize target. Defaults to True. + patch_size (int): Patch size, equal to downsample ratio of backbone. + Defaults to 32. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = True, + patch_size: int = 32) -> None: + super().__init__() + self.norm_pix = norm_pix + self.patch_size = patch_size + self.loss = MODELS.build(loss) + + def patchify(self, imgs): + """Split images into non-overlapped patches. + + Args: + imgs (torch.Tensor): A batch of images, of shape B x C x H x W. + Returns: + torch.Tensor: Patchified images. The shape is B x L x D. + """ + p = self.patch_size + assert len(imgs.shape + ) == 4 and imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0 + + B, C, ori_h, ori_w = imgs.shape + h = ori_h // p + w = ori_w // p + x = imgs.reshape(shape=(B, C, h, p, w, p)) + x = torch.einsum('bchpwq->bhwpqc', x) + + # (B, f*f, downsample_raito*downsample_raito*3) + x = x.reshape(shape=(B, h * w, p**2 * C)) + return x + + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + """Construct the reconstruction target. + + In addition to splitting images into tokens, this module will also + normalize the image according to ``norm_pix``. + Args: + target (torch.Tensor): Image with the shape of B x 3 x H x W + Returns: + torch.Tensor: Tokenized images with the shape of B x L x C + """ + target = self.patchify(target) + if self.norm_pix: + # normalize the target image + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + return target + + def forward(self, pred: torch.Tensor, target: torch.Tensor, + active_mask: torch.Tensor) -> torch.Tensor: + """Forward function of MAE head. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + active_mask (torch.Tensor): The mask of the target image. + Returns: + torch.Tensor: The reconstruction loss. + """ + # (B, C, H, W) -> (B, L, C) and perform normalization + target = self.construct_target(target) + + # (B, C, H, W) -> (B, L, C) + pred = self.patchify(pred) + + # (B, 1, f, f) -> (B, L) + non_active_mask = active_mask.logical_not().int().view( + active_mask.shape[0], -1) + + # MSE loss on masked patches + loss = self.loss(pred, target, non_active_mask) + return loss diff --git a/mmpretrain/models/heads/stacked_head.py b/mmpretrain/models/heads/stacked_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd819de8e8daf162bb906d5524871577754fa1f --- /dev/null +++ b/mmpretrain/models/heads/stacked_head.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +class LinearBlock(BaseModule): + """Linear block for StackedLinearClsHead.""" + + def __init__(self, + in_channels, + out_channels, + dropout_rate=0., + norm_cfg=None, + act_cfg=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.fc = nn.Linear(in_channels, out_channels) + + self.norm = None + self.act = None + self.dropout = None + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + if act_cfg is not None: + self.act = build_activation_layer(act_cfg) + if dropout_rate > 0: + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, x): + """The forward process.""" + x = self.fc(x) + if self.norm is not None: + x = self.norm(x) + if self.act is not None: + x = self.act(x) + if self.dropout is not None: + x = self.dropout(x) + return x + + +@MODELS.register_module() +class StackedLinearClsHead(ClsHead): + """Classifier head with several hidden fc layer and a output fc layer. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + mid_channels (Sequence[int]): Number of channels in the hidden fc + layers. + dropout_rate (float): Dropout rate after each hidden fc layer, + except the last layer. Defaults to 0. + norm_cfg (dict, optional): Config dict of normalization layer after + each hidden fc layer, except the last layer. Defaults to None. + act_cfg (dict, optional): Config dict of activation function after each + hidden layer, except the last layer. Defaults to use "ReLU". + """ + + def __init__(self, + num_classes: int, + in_channels: int, + mid_channels: Sequence[int], + dropout_rate: float = 0., + norm_cfg: Optional[Dict] = None, + act_cfg: Optional[Dict] = dict(type='ReLU'), + **kwargs): + super(StackedLinearClsHead, self).__init__(**kwargs) + self.num_classes = num_classes + self.in_channels = in_channels + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + assert isinstance(mid_channels, Sequence), \ + f'`mid_channels` of StackedLinearClsHead should be a sequence, ' \ + f'instead of {type(mid_channels)}' + self.mid_channels = mid_channels + + self.dropout_rate = dropout_rate + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self._init_layers() + + def _init_layers(self): + """"Init layers.""" + self.layers = ModuleList() + in_channels = self.in_channels + for hidden_channels in self.mid_channels: + self.layers.append( + LinearBlock( + in_channels, + hidden_channels, + dropout_rate=self.dropout_rate, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + in_channels = hidden_channels + + self.layers.append( + LinearBlock( + self.mid_channels[-1], + self.num_classes, + dropout_rate=0., + norm_cfg=None, + act_cfg=None)) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. + """ + x = feats[-1] + for layer in self.layers[:-1]: + x = layer(x) + return x + + @property + def fc(self): + """Full connected layer.""" + return self.layers[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/swav_head.py b/mmpretrain/models/heads/swav_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3a30236e019822a166e25551f77feec8228d84 --- /dev/null +++ b/mmpretrain/models/heads/swav_head.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SwAVHead(BaseModule): + """Head for SwAV Pre-training. + + Args: + loss (dict): Config dict for module of loss functions. + """ + + def __init__(self, loss: dict) -> None: + super().__init__() + self.loss_module = MODELS.build(loss) + + def loss(self, pred: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): NxC input features. + + Returns: + torch.Tensor: The SwAV loss. + """ + loss = self.loss_module(pred) + + return loss diff --git a/mmpretrain/models/heads/vig_head.py b/mmpretrain/models/heads/vig_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb984deb4b0b6bf162263a86771f2d3eba71cbd --- /dev/null +++ b/mmpretrain/models/heads/vig_head.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class VigClsHead(ClsHead): + """The classification head for Vision GNN. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int): The number of middle channels. Defaults to 1024. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. + dropout (float): The dropout rate. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + hidden_dim: int = 1024, + act_cfg: dict = dict(type='GELU'), + dropout: float = 0., + **kwargs): + super().__init__(**kwargs) + + self.fc1 = nn.Linear(in_channels, hidden_dim) + self.bn = nn.BatchNorm1d(hidden_dim) + self.act = build_activation_layer(act_cfg) + self.drop = nn.Dropout(dropout) + self.fc2 = nn.Linear(hidden_dim, num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a stage_blocks stage. In ``VigClsHead``, we just obtain the + feature of the last stage. + """ + feats = feats[-1] + feats = self.fc1(feats) + feats = self.bn(feats) + feats = self.act(feats) + feats = self.drop(feats) + + return feats + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc2(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/vision_transformer_head.py b/mmpretrain/models/heads/vision_transformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..83e8fca125cd626c51abfcc87b28387f654618f9 --- /dev/null +++ b/mmpretrain/models/heads/vision_transformer_head.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from collections import OrderedDict +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer +from mmengine.model import Sequential +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class VisionTransformerClsHead(ClsHead): + """Vision Transformer classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int, optional): Number of the dimensions for hidden layer. + Defaults to None, which means no extra hidden layer. + act_cfg (dict): The activation config. Only available during + pre-training. Defaults to ``dict(type='Tanh')``. + init_cfg (dict): The extra initialization configs. Defaults to + ``dict(type='Constant', layer='Linear', val=0)``. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + hidden_dim: Optional[int] = None, + act_cfg: dict = dict(type='Tanh'), + init_cfg: dict = dict(type='Constant', layer='Linear', val=0), + **kwargs): + super(VisionTransformerClsHead, self).__init__( + init_cfg=init_cfg, **kwargs) + self.in_channels = in_channels + self.num_classes = num_classes + self.hidden_dim = hidden_dim + self.act_cfg = act_cfg + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self._init_layers() + + def _init_layers(self): + """"Init hidden layer if exists.""" + if self.hidden_dim is None: + layers = [('head', nn.Linear(self.in_channels, self.num_classes))] + else: + layers = [ + ('pre_logits', nn.Linear(self.in_channels, self.hidden_dim)), + ('act', build_activation_layer(self.act_cfg)), + ('head', nn.Linear(self.hidden_dim, self.num_classes)), + ] + self.layers = Sequential(OrderedDict(layers)) + + def init_weights(self): + """"Init weights of hidden layer if exists.""" + super(VisionTransformerClsHead, self).init_weights() + # Modified from ClassyVision + if hasattr(self.layers, 'pre_logits'): + # Lecun norm + trunc_normal_( + self.layers.pre_logits.weight, + std=math.sqrt(1 / self.layers.pre_logits.in_features)) + nn.init.zeros_(self.layers.pre_logits.bias) + + def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of list of tensor, and each tensor is + the feature of a backbone stage. In ``VisionTransformerClsHead``, we + obtain the feature of the last stage and forward in hidden layer if + exists. + """ + feat = feats[-1] # Obtain feature of the last scale. + # For backward-compatibility with the previous ViT output + cls_token = feat[-1] if isinstance(feat, list) else feat + if self.hidden_dim is None: + return cls_token + else: + x = self.layers.pre_logits(cls_token) + return self.layers.act(x) + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.layers.head(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/vqa_head.py b/mmpretrain/models/heads/vqa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b5fe532874e2e8325caa3090d3be66b098ad46 --- /dev/null +++ b/mmpretrain/models/heads/vqa_head.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import mmengine +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class VQAGenerationHead(BaseModule): + """Generation head for multi-modal pre-trained task, adapted by BLIP. + Normally used for qa generation task (open-set) + + Args: + decoder (dict): Decoder for decoding answers. + inference_method (str): Inference method. One of 'rank', 'generate'. + - If 'rank', the model will return answers with the highest + probability from the answer list. + - If 'generate', the model will generate answers. + - Only for test, not for train / val. + num_beams (int): Number of beams for beam search. 1 means no beam + search. Only support when inference_method=='generate'. + Defaults to 3. + num_ans_candidates (int): Number of answer candidates, used to filter + out answers with low probability. Only support when + inference_method=='rank'. Defaults to 128. + loss (dict or nn.Module): Config of loss or module of loss. Defaults to + ``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + answer_list_path (str, optional): Path to `answer_list.json` + (json file of a answer list). Required when + inference_method=='rank'. + + + TODO: `mmcls.LabelSmoothLoss` has not support `ignore_index` param. + Now using `nn.CrossEntropyLoss`, without label_smoothing, in order to + maintain compatibility with torch < 1.10.0 + """ + + def __init__( + self, + decoder: dict, + inference_method: str = 'generate', + num_beams: int = 3, + num_ans_candidates: int = 128, + loss: Union[dict, nn.Module] = nn.CrossEntropyLoss( + reduction='none', ignore_index=-100), + init_cfg: Optional[dict] = None, + answer_list_path: Optional[str] = None, + ) -> None: + + super(VQAGenerationHead, self).__init__(init_cfg=init_cfg) + self.decoder = MODELS.build(decoder) + + if inference_method == 'generate': + assert isinstance(num_beams, int), \ + 'for VQA `generate` mode, `num_beams` must be a int.' + self.num_beams = num_beams + self.num_ans_candidates = None + self.answer_list = None + + elif inference_method == 'rank': + assert isinstance(num_ans_candidates, int), \ + 'for VQA `rank` mode, `num_ans_candidates` must be a int.' + assert isinstance(answer_list_path, str), \ + 'for VQA `rank` mode, `answer_list_path` must be set as ' \ + 'the path to `answer_list.json`.' + self.num_beams = None + self.answer_list = mmengine.load(answer_list_path) + if isinstance(self.answer_list, dict): + self.answer_list = list(self.answer_list.keys()) + assert isinstance(self.answer_list, list) and all( + isinstance(item, str) for item in self.answer_list), \ + 'for VQA `rank` mode, `answer_list.json` must be a list of str' + self.num_ans_candidates = min(num_ans_candidates, + len(self.answer_list)) + + else: + raise AssertionError( + 'for VQA, `inference_method` must be "generate" or "rank", ' + 'got {}.'.format(inference_method)) + + self.inference_method = inference_method + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + def forward(self, feats: dict): + prediction_logits = self.decoder( + feats['answer_input_ids'], + attention_mask=feats['answer_attention_mask'], + encoder_hidden_states=feats['question_states'], + encoder_attention_mask=feats['question_atts'], + labels=feats['answer_targets'], + return_dict=True, + return_logits=True, # directly return logits, not computing loss + reduction='none', + ) + return prediction_logits + + def loss(self, feats: dict, data_samples=None): + """Calculate losses from the extracted features. + + Args: + feats (dict): The features extracted from the backbone. + data_samples (List[BaseDataElement]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + shifted_prediction_scores = self(feats) + labels = feats['answer_targets'] + lm_loss = None + + # we are doing next-token prediction; + # shift prediction scores and input ids by one + labels = labels[:, 1:].contiguous() + lm_loss = self.loss_module( + shifted_prediction_scores.view(-1, + self.decoder.med_config.vocab_size), + labels.view(-1)) + lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1) + # compute weighted loss + losses = dict() + loss = feats['answer_weight'] * lm_loss + loss = loss.sum() / feats['batch_size'] + losses['vqa_loss'] = loss + + return losses + + def predict_rank(self, feats: dict, data_samples=None): + """Predict rank in a close-set answer list.""" + question_states = feats['multimodal_embeds'] + question_atts = feats['question_atts'] + answer_candidates = feats['answer_candidates'] + assert answer_candidates is not None + + answer_ids = answer_candidates.input_ids + answer_atts = answer_candidates.attention_mask + num_ques = question_states.size(0) + start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token + + start_output = self.decoder( + start_ids, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + return_dict=True, + reduction='none', + ) + logits = start_output.logits[:, 0, :] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:, 1] + prob_first_token = F.softmax( + logits, dim=1).index_select( + dim=1, index=answer_first_token) + topk_probs, topk_ids = prob_first_token.topk( + self.num_ans_candidates, dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids, dim=0) + input_atts = torch.cat(input_atts, dim=0) + + targets_ids = input_ids.masked_fill(input_ids == feats['pad_token_id'], + -100) + + def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([ + init_dim * np.arange(n_tile) + i for i in range(init_dim) + ])) + return torch.index_select(x, dim, order_index.to(x.device)) + + # repeat encoder's output for top-k answers + question_states = tile(question_states, 0, self.num_ans_candidates) + question_atts = tile(question_atts, 0, self.num_ans_candidates) + + output = self.decoder( + input_ids, + attention_mask=input_atts, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=targets_ids, + return_dict=True, + reduction='none', + ) + + log_probs_sum = -output.loss + log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates) + + max_topk_ids = log_probs_sum.argmax(dim=1) + max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids] + + answers = [self.answer_list[max_id] for max_id in max_ids] + + return answers + + def predict_generate(self, feats: dict, data_samples=None): + """Predict answers in a generation manner.""" + device = feats['multimodal_embeds'].device + question_states = feats['multimodal_embeds'] + question_atts = torch.ones( + question_states.size()[:-1], dtype=torch.long).to(device) + model_kwargs = { + 'encoder_hidden_states': question_states, + 'encoder_attention_mask': question_atts + } + + bos_ids = torch.full((feats['multimodal_embeds'].shape[0], 1), + fill_value=feats['bos_token_id'], + device=device) + + outputs = self.decoder.generate( + input_ids=bos_ids, + max_length=10, + min_length=1, + num_beams=self.num_beams, + eos_token_id=feats['sep_token_id'], + pad_token_id=feats['pad_token_id'], + **model_kwargs) + + return outputs + + def predict(self, feats: dict, data_samples=None): + """Predict results from the extracted features.""" + if self.inference_method == 'generate': + return self.predict_generate(feats, data_samples) + elif self.inference_method == 'rank': + return self.predict_rank(feats, data_samples) diff --git a/mmpretrain/models/losses/__init__.py b/mmpretrain/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b2ed725ef76df7e18bf9283ec84b3b12e3d2cf --- /dev/null +++ b/mmpretrain/models/losses/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .asymmetric_loss import AsymmetricLoss, asymmetric_loss +from .cae_loss import CAELoss +from .cosine_similarity_loss import CosineSimilarityLoss +from .cross_correlation_loss import CrossCorrelationLoss +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy) +from .focal_loss import FocalLoss, sigmoid_focal_loss +from .label_smooth_loss import LabelSmoothLoss +from .reconstruction_loss import PixelReconstructionLoss +from .seesaw_loss import SeesawLoss +from .swav_loss import SwAVLoss +from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss, + weighted_loss) + +__all__ = [ + 'asymmetric_loss', + 'AsymmetricLoss', + 'cross_entropy', + 'binary_cross_entropy', + 'CrossEntropyLoss', + 'reduce_loss', + 'weight_reduce_loss', + 'LabelSmoothLoss', + 'weighted_loss', + 'FocalLoss', + 'sigmoid_focal_loss', + 'convert_to_one_hot', + 'SeesawLoss', + 'CAELoss', + 'CosineSimilarityLoss', + 'CrossCorrelationLoss', + 'PixelReconstructionLoss', + 'SwAVLoss', +] diff --git a/mmpretrain/models/losses/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e94674848f05fdbfa39aaabc3e579112e83c47f2 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/asymmetric_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/asymmetric_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8019a3c23951f409c171f3fd33a383bfe88d3e5 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/asymmetric_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/cae_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/cae_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45bc226be1b3310df1ee09fcc1f44d7c2d8be18b Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/cae_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/cosine_similarity_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/cosine_similarity_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da8ea49fab4a210841854e11e89a3c6858565173 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/cosine_similarity_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/cross_correlation_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/cross_correlation_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d84b345a4f9e6b4bb41210f1ecb25377e59e564b Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/cross_correlation_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daec8b46a1a39ed6d75912625996b6033dfffd36 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/focal_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/focal_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..520f0b1d0dffd1e488ff1c845b6ec8cee3ae6b4a Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/focal_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/label_smooth_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/label_smooth_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..476532d1b415c6155b1c79f1dd250badb0db52e2 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/label_smooth_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/reconstruction_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/reconstruction_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fd64c95b56206bb5bdc184ec7ffa6d02c319230 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/reconstruction_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/seesaw_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/seesaw_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c6542aed8aabf638324f2740f4074c499e30f03 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/seesaw_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/swav_loss.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/swav_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05bc0118ecf71eab64565bf7bed61bb305cd7552 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/swav_loss.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/__pycache__/utils.cpython-38.pyc b/mmpretrain/models/losses/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8124b50532bd2f62a9295d615569f242f9d16e25 Binary files /dev/null and b/mmpretrain/models/losses/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpretrain/models/losses/asymmetric_loss.py b/mmpretrain/models/losses/asymmetric_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc9707da8475b5e87d2b4f8a5a2cf669d7ffe2f --- /dev/null +++ b/mmpretrain/models/losses/asymmetric_loss.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .utils import convert_to_one_hot, weight_reduce_loss + + +def asymmetric_loss(pred, + target, + weight=None, + gamma_pos=1.0, + gamma_neg=4.0, + clip=0.05, + reduction='mean', + avg_factor=None, + use_sigmoid=True, + eps=1e-8): + r"""asymmetric loss. + + Please refer to the `paper `__ for + details. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction with + shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, ). Defaults to None. + gamma_pos (float): positive focusing parameter. Defaults to 0.0. + gamma_neg (float): Negative focusing parameter. We usually set + gamma_neg > gamma_pos. Defaults to 4.0. + clip (float, optional): Probability margin. Defaults to 0.05. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , loss + is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + use_sigmoid (bool): Whether the prediction uses sigmoid instead + of softmax. Defaults to True. + eps (float): The minimum value of the argument of logarithm. Defaults + to 1e-8. + + Returns: + torch.Tensor: Loss. + """ + assert pred.shape == \ + target.shape, 'pred and target should be in the same shape.' + + if use_sigmoid: + pred_sigmoid = pred.sigmoid() + else: + pred_sigmoid = nn.functional.softmax(pred, dim=-1) + + target = target.type_as(pred) + + if clip and clip > 0: + pt = (1 - pred_sigmoid + + clip).clamp(max=1) * (1 - target) + pred_sigmoid * target + else: + pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target + asymmetric_weight = (1 - pt).pow(gamma_pos * target + gamma_neg * + (1 - target)) + loss = -torch.log(pt.clamp(min=eps)) * asymmetric_weight + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class AsymmetricLoss(nn.Module): + """asymmetric loss. + + Args: + gamma_pos (float): positive focusing parameter. + Defaults to 0.0. + gamma_neg (float): Negative focusing parameter. We + usually set gamma_neg > gamma_pos. Defaults to 4.0. + clip (float, optional): Probability margin. Defaults to 0.05. + reduction (str): The method used to reduce the loss into + a scalar. + loss_weight (float): Weight of loss. Defaults to 1.0. + use_sigmoid (bool): Whether the prediction uses sigmoid instead + of softmax. Defaults to True. + eps (float): The minimum value of the argument of logarithm. Defaults + to 1e-8. + """ + + def __init__(self, + gamma_pos=0.0, + gamma_neg=4.0, + clip=0.05, + reduction='mean', + loss_weight=1.0, + use_sigmoid=True, + eps=1e-8): + super(AsymmetricLoss, self).__init__() + self.gamma_pos = gamma_pos + self.gamma_neg = gamma_neg + self.clip = clip + self.reduction = reduction + self.loss_weight = loss_weight + self.use_sigmoid = use_sigmoid + self.eps = eps + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + r"""asymmetric loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction + with shape (N, \*), N or (N,1). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The method used to reduce the + loss into a scalar. Options are "none", "mean" and "sum". + Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if target.dim() == 1 or (target.dim() == 2 and target.shape[1] == 1): + target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1]) + loss_cls = self.loss_weight * asymmetric_loss( + pred, + target, + weight, + gamma_pos=self.gamma_pos, + gamma_neg=self.gamma_neg, + clip=self.clip, + reduction=reduction, + avg_factor=avg_factor, + use_sigmoid=self.use_sigmoid, + eps=self.eps) + return loss_cls diff --git a/mmpretrain/models/losses/cae_loss.py b/mmpretrain/models/losses/cae_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc081b603361e9b06c96cf836941fa971a4b4c4 --- /dev/null +++ b/mmpretrain/models/losses/cae_loss.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CAELoss(BaseModule): + """Loss function for CAE. + + Compute the align loss and the main loss. + + Args: + lambd (float): The weight for the align loss. + """ + + def __init__(self, lambd: float) -> None: + super().__init__() + self.lambd = lambd + self.loss_cross_entropy = nn.CrossEntropyLoss() + self.loss_mse = nn.MSELoss() + + def forward( + self, logits: torch.Tensor, target: torch.Tensor, + latent_pred: torch.Tensor, + latent_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function of CAE Loss. + + Args: + logits (torch.Tensor): The outputs from the decoder. + target (torch.Tensor): The targets generated by dalle. + latent_pred (torch.Tensor): The latent prediction from the + regressor. + latent_target (torch.Tensor): The latent target from the teacher + network. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The main loss and align loss. + """ + loss_main = self.loss_cross_entropy(logits, target) + loss_align = self.loss_mse(latent_pred, + latent_target.detach()) * self.lambd + + return loss_main, loss_align diff --git a/mmpretrain/models/losses/cosine_similarity_loss.py b/mmpretrain/models/losses/cosine_similarity_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a5931e24686bd560196e1e310fc283fc4c9d4d --- /dev/null +++ b/mmpretrain/models/losses/cosine_similarity_loss.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Optional + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CosineSimilarityLoss(BaseModule): + """Cosine similarity loss function. + + Compute the similarity between two features and optimize that similarity as + loss. + + Args: + shift_factor (float): The shift factor of cosine similarity. + Default: 0.0. + scale_factor (float): The scale factor of cosine similarity. + Default: 1.0. + """ + + def __init__(self, + shift_factor: float = 0.0, + scale_factor: float = 1.0) -> None: + super().__init__() + self.shift_factor = shift_factor + self.scale_factor = scale_factor + + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward function of cosine similarity loss. + + Args: + pred (torch.Tensor): The predicted features. + target (torch.Tensor): The target features. + + Returns: + torch.Tensor: The cosine similarity loss. + """ + pred_norm = nn.functional.normalize(pred, dim=-1) + target_norm = nn.functional.normalize(target, dim=-1) + loss = self.shift_factor - self.scale_factor * ( + pred_norm * target_norm).sum(dim=-1) + + if mask is None: + loss = loss.mean() + else: + loss = (loss * mask).sum() / mask.sum() + return loss diff --git a/mmpretrain/models/losses/cross_correlation_loss.py b/mmpretrain/models/losses/cross_correlation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d26ce3ddbd7b41778cbf25147df39da256788dd1 --- /dev/null +++ b/mmpretrain/models/losses/cross_correlation_loss.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CrossCorrelationLoss(BaseModule): + """Cross correlation loss function. + + Compute the on-diagnal and off-diagnal loss. + + Args: + lambd (float): The weight for the off-diag loss. + """ + + def __init__(self, lambd: float = 0.0051) -> None: + super().__init__() + self.lambd = lambd + + def forward(self, cross_correlation_matrix: torch.Tensor) -> torch.Tensor: + """Forward function of cross correlation loss. + + Args: + cross_correlation_matrix (torch.Tensor): The cross correlation + matrix. + + Returns: + torch.Tensor: cross correlation loss. + """ + # loss + on_diag = torch.diagonal(cross_correlation_matrix).add_(-1).pow_( + 2).sum() + off_diag = self.off_diagonal(cross_correlation_matrix).pow_(2).sum() + loss = on_diag + self.lambd * off_diag + return loss + + def off_diagonal(self, x: torch.Tensor) -> torch.Tensor: + """Rreturn a flattened view of the off-diagonal elements of a square + matrix.""" + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() diff --git a/mmpretrain/models/losses/cross_entropy_loss.py b/mmpretrain/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5d418beb812f8493668aeff99198555068a55435 --- /dev/null +++ b/mmpretrain/models/losses/cross_entropy_loss.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None): + """Calculate the CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + label (torch.Tensor): The gt label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # element-wise losses + loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none') + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def soft_cross_entropy(pred, + label, + weight=None, + reduction='mean', + class_weight=None, + avg_factor=None): + """Calculate the Soft CrossEntropy loss. The label can be float. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + label (torch.Tensor): The gt label of the prediction with shape (N, C). + When using "mixup", the label can be float. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # element-wise losses + loss = -label * F.log_softmax(pred, dim=-1) + if class_weight is not None: + loss *= class_weight + loss = loss.sum(dim=-1) + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + pos_weight=None): + r"""Calculate the binary CrossEntropy loss with logits. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + label (torch.Tensor): The gt label with shape (N, \*). + weight (torch.Tensor, optional): Element-wise weight of loss with shape + (N, ). Defaults to None. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , loss + is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + pos_weight (torch.Tensor, optional): The positive weight for each + class with shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # Ensure that the size of class_weight is consistent with pred and label to + # avoid automatic boracast, + assert pred.dim() == label.dim() + + if class_weight is not None: + N = pred.size()[0] + class_weight = class_weight.repeat(N, 1) + loss = F.binary_cross_entropy_with_logits( + pred, + label.float(), # only accepts float type tensor + weight=class_weight, + pos_weight=pos_weight, + reduction='none') + + # apply weights and do the reduction + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + return loss + + +@MODELS.register_module() +class CrossEntropyLoss(nn.Module): + """Cross entropy loss. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_soft (bool): Whether to use the soft version of CrossEntropyLoss. + Defaults to False. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". Defaults to 'mean'. + loss_weight (float): Weight of the loss. Defaults to 1.0. + class_weight (List[float], optional): The weight for each class with + shape (C), C is the number of classes. Default None. + pos_weight (List[float], optional): The positive weight for each + class with shape (C), C is the number of classes. Only enabled in + BCE loss when ``use_sigmoid`` is True. Default None. + """ + + def __init__(self, + use_sigmoid=False, + use_soft=False, + reduction='mean', + loss_weight=1.0, + class_weight=None, + pos_weight=None): + super(CrossEntropyLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.use_soft = use_soft + assert not ( + self.use_soft and self.use_sigmoid + ), 'use_sigmoid and use_soft could not be set simultaneously' + + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + self.pos_weight = pos_weight + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_soft: + self.cls_criterion = soft_cross_entropy + else: + self.cls_criterion = cross_entropy + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # only BCE loss has pos_weight + if self.pos_weight is not None and self.use_sigmoid: + pos_weight = cls_score.new_tensor(self.pos_weight) + kwargs.update({'pos_weight': pos_weight}) + else: + pos_weight = None + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/mmpretrain/models/losses/focal_loss.py b/mmpretrain/models/losses/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2cf5035aedfd923ae388b264a7457312b274fd --- /dev/null +++ b/mmpretrain/models/losses/focal_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import convert_to_one_hot, weight_reduce_loss + + +def sigmoid_focal_loss(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + r"""Sigmoid focal loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction with + shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, ). Defaults to None. + gamma (float): The gamma for calculating the modulating factor. + Defaults to 2.0. + alpha (float): A balanced form for Focal Loss. Defaults to 0.25. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , + loss is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + assert pred.shape == \ + target.shape, 'pred and target should be in the same shape.' + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * pt.pow(gamma) + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class FocalLoss(nn.Module): + """Focal loss. + + Args: + gamma (float): Focusing parameter in focal loss. + Defaults to 2.0. + alpha (float): The parameter in balanced form of focal + loss. Defaults to 0.25. + reduction (str): The method used to reduce the loss into + a scalar. Options are "none" and "mean". Defaults to 'mean'. + loss_weight (float): Weight of loss. Defaults to 1.0. + """ + + def __init__(self, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=1.0): + + super(FocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + r"""Sigmoid focal loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction + with shape (N, \*), N or (N,1). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The method used to reduce the + loss into a scalar. Options are "none", "mean" and "sum". + Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if target.dim() == 1 or (target.dim() == 2 and target.shape[1] == 1): + target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1]) + loss_cls = self.loss_weight * sigmoid_focal_loss( + pred, + target, + weight, + gamma=self.gamma, + alpha=self.alpha, + reduction=reduction, + avg_factor=avg_factor) + return loss_cls diff --git a/mmpretrain/models/losses/label_smooth_loss.py b/mmpretrain/models/losses/label_smooth_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f117df33b07c05ee7516f0b99d985f0d001b2d31 --- /dev/null +++ b/mmpretrain/models/losses/label_smooth_loss.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .cross_entropy_loss import CrossEntropyLoss +from .utils import convert_to_one_hot + + +@MODELS.register_module() +class LabelSmoothLoss(nn.Module): + r"""Initializer for the label smoothed cross entropy loss. + + Refers to `Rethinking the Inception Architecture for Computer Vision + `_ + + This decreases gap between output scores and encourages generalization. + Labels provided to forward can be one-hot like vectors (NxC) or class + indices (Nx1). + And this accepts linear combination of one-hot like labels from mixup or + cutmix except multi-label task. + + Args: + label_smooth_val (float): The degree of label smoothing. + num_classes (int, optional): Number of classes. Defaults to None. + mode (str): Refers to notes, Options are 'original', 'classy_vision', + 'multi_label'. Defaults to 'original'. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid of + softmax. Defaults to None, which means to use sigmoid in + "multi_label" mode and not use in other modes. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". Defaults to 'mean'. + loss_weight (float): Weight of the loss. Defaults to 1.0. + + Notes: + - if the mode is **"original"**, this will use the same label smooth + method as the original paper as: + + .. math:: + (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K} + + where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is the + ``num_classes`` and :math:`\delta_{k, y}` is Dirac delta, which + equals 1 for :math:`k=y` and 0 otherwise. + + - if the mode is **"classy_vision"**, this will use the same label + smooth method as the facebookresearch/ClassyVision repo as: + + .. math:: + \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon} + + - if the mode is **"multi_label"**, this will accept labels from + multi-label task and smoothing them as: + + .. math:: + (1-2\epsilon)\delta_{k, y} + \epsilon + """ + + def __init__(self, + label_smooth_val, + num_classes=None, + use_sigmoid=None, + mode='original', + reduction='mean', + loss_weight=1.0, + class_weight=None, + pos_weight=None): + super().__init__() + self.num_classes = num_classes + self.loss_weight = loss_weight + + assert (isinstance(label_smooth_val, float) + and 0 <= label_smooth_val < 1), \ + f'LabelSmoothLoss accepts a float label_smooth_val ' \ + f'over [0, 1), but gets {label_smooth_val}' + self.label_smooth_val = label_smooth_val + + accept_reduction = {'none', 'mean', 'sum'} + assert reduction in accept_reduction, \ + f'LabelSmoothLoss supports reduction {accept_reduction}, ' \ + f'but gets {mode}.' + self.reduction = reduction + + accept_mode = {'original', 'classy_vision', 'multi_label'} + assert mode in accept_mode, \ + f'LabelSmoothLoss supports mode {accept_mode}, but gets {mode}.' + self.mode = mode + + self._eps = label_smooth_val + if mode == 'classy_vision': + self._eps = label_smooth_val / (1 + label_smooth_val) + + if mode == 'multi_label': + if not use_sigmoid: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().warning( + 'For multi-label tasks, please set `use_sigmoid=True` ' + 'to use binary cross entropy.') + self.smooth_label = self.multilabel_smooth_label + use_sigmoid = True if use_sigmoid is None else use_sigmoid + else: + self.smooth_label = self.original_smooth_label + use_sigmoid = False if use_sigmoid is None else use_sigmoid + + self.ce = CrossEntropyLoss( + use_sigmoid=use_sigmoid, + use_soft=not use_sigmoid, + reduction=reduction, + class_weight=class_weight, + pos_weight=pos_weight) + + def generate_one_hot_like_label(self, label): + """This function takes one-hot or index label vectors and computes one- + hot like label vectors (float)""" + # check if targets are inputted as class integers + if label.dim() == 1 or (label.dim() == 2 and label.shape[1] == 1): + label = convert_to_one_hot(label.view(-1, 1), self.num_classes) + return label.float() + + def original_smooth_label(self, one_hot_like_label): + assert self.num_classes > 0 + smooth_label = one_hot_like_label * (1 - self._eps) + smooth_label += self._eps / self.num_classes + return smooth_label + + def multilabel_smooth_label(self, one_hot_like_label): + assert self.num_classes > 0 + smooth_label = torch.full_like(one_hot_like_label, self._eps) + smooth_label.masked_fill_(one_hot_like_label > 0, 1 - self._eps) + return smooth_label + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + r"""Label smooth loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + label (torch.Tensor): The ground truth label of the prediction + with shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The method used to reduce the + loss into a scalar. Options are "none", "mean" and "sum". + Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + if self.num_classes is not None: + assert self.num_classes == cls_score.shape[1], \ + f'num_classes should equal to cls_score.shape[1], ' \ + f'but got num_classes: {self.num_classes} and ' \ + f'cls_score.shape[1]: {cls_score.shape[1]}' + else: + self.num_classes = cls_score.shape[1] + + one_hot_like_label = self.generate_one_hot_like_label(label=label) + assert one_hot_like_label.shape == cls_score.shape, \ + f'LabelSmoothLoss requires output and target ' \ + f'to be same shape, but got output.shape: {cls_score.shape} ' \ + f'and target.shape: {one_hot_like_label.shape}' + + smoothed_label = self.smooth_label(one_hot_like_label) + return self.loss_weight * self.ce.forward( + cls_score, + smoothed_label, + weight=weight, + avg_factor=avg_factor, + reduction_override=reduction_override, + **kwargs) diff --git a/mmpretrain/models/losses/reconstruction_loss.py b/mmpretrain/models/losses/reconstruction_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..40e6bfd707b8e378f1ec656cfb443c27e8bbdbb3 --- /dev/null +++ b/mmpretrain/models/losses/reconstruction_loss.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class PixelReconstructionLoss(BaseModule): + """Loss for the reconstruction of pixel in Masked Image Modeling. + + This module measures the distance between the target image and the + reconstructed image and compute the loss to optimize the model. Currently, + This module only provides L1 and L2 loss to penalize the reconstructed + error. In addition, a mask can be passed in the ``forward`` function to + only apply loss on visible region, like that in MAE. + + Args: + criterion (str): The loss the penalize the reconstructed error. + Currently, only supports L1 and L2 loss + channel (int, optional): The number of channels to average the + reconstruction loss. If not None, the reconstruction loss + will be divided by the channel. Defaults to None. + """ + + def __init__(self, criterion: str, channel: Optional[int] = None) -> None: + super().__init__() + + if criterion == 'L1': + self.penalty = torch.nn.L1Loss(reduction='none') + elif criterion == 'L2': + self.penalty = torch.nn.MSELoss(reduction='none') + else: + raise NotImplementedError(f'Currently, PixelReconstructionLoss \ + only supports L1 and L2 loss, but get {criterion}') + + self.channel = channel if channel is not None else 1 + + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward function to compute the reconstrction loss. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + loss = self.penalty(pred, target) + + # if the dim of the loss is 3, take the average of the loss + # along the last dim + if len(loss.shape) == 3: + loss = loss.mean(dim=-1) + + if mask is None: + loss = loss.mean() + else: + loss = (loss * mask).sum() / mask.sum() / self.channel + + return loss diff --git a/mmpretrain/models/losses/seesaw_loss.py b/mmpretrain/models/losses/seesaw_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4aaaa451b41ea7e86b7efbfe1c0b6ce8b3756d80 --- /dev/null +++ b/mmpretrain/models/losses/seesaw_loss.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# migrate from mmdetection with modifications +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import weight_reduce_loss + + +def seesaw_ce_loss(cls_score, + labels, + weight, + cum_samples, + num_classes, + p, + q, + eps, + reduction='mean', + avg_factor=None): + """Calculate the Seesaw CrossEntropy loss. + + Args: + cls_score (torch.Tensor): The prediction with shape (N, C), + C is the number of classes. + labels (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor): Sample-wise loss weight. + cum_samples (torch.Tensor): Cumulative samples for each category. + num_classes (int): The number of classes. + p (float): The ``p`` in the mitigation factor. + q (float): The ``q`` in the compenstation factor. + eps (float): The minimal value of divisor to smooth + the computation of compensation factor + reduction (str, optional): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + + Returns: + torch.Tensor: The calculated loss + """ + assert cls_score.size(-1) == num_classes + assert len(cum_samples) == num_classes + + onehot_labels = F.one_hot(labels, num_classes) + seesaw_weights = cls_score.new_ones(onehot_labels.size()) + + # mitigation factor + if p > 0: + sample_ratio_matrix = cum_samples[None, :].clamp( + min=1) / cum_samples[:, None].clamp(min=1) + index = (sample_ratio_matrix < 1.0).float() + sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index + ) # M_{ij} + mitigation_factor = sample_weights[labels.long(), :] + seesaw_weights = seesaw_weights * mitigation_factor + + # compensation factor + if q > 0: + scores = F.softmax(cls_score.detach(), dim=1) + self_scores = scores[ + torch.arange(0, len(scores)).to(scores.device).long(), + labels.long()] + score_matrix = scores / self_scores[:, None].clamp(min=eps) + index = (score_matrix > 1.0).float() + compensation_factor = score_matrix.pow(q) * index + (1 - index) + seesaw_weights = seesaw_weights * compensation_factor + + cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels)) + + loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none') + + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + return loss + + +@MODELS.register_module() +class SeesawLoss(nn.Module): + """Implementation of seesaw loss. + + Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021) + `_ + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid of softmax. + Only False is supported. Defaults to False. + p (float): The ``p`` in the mitigation factor. + Defaults to 0.8. + q (float): The ``q`` in the compenstation factor. + Defaults to 2.0. + num_classes (int): The number of classes. + Defaults to 1000 for the ImageNet dataset. + eps (float): The minimal value of divisor to smooth + the computation of compensation factor, default to 1e-2. + reduction (str): The method that reduces the loss to a scalar. + Options are "none", "mean" and "sum". Defaults to "mean". + loss_weight (float): The weight of the loss. Defaults to 1.0 + """ + + def __init__(self, + use_sigmoid=False, + p=0.8, + q=2.0, + num_classes=1000, + eps=1e-2, + reduction='mean', + loss_weight=1.0): + super(SeesawLoss, self).__init__() + assert not use_sigmoid, '`use_sigmoid` is not supported' + self.use_sigmoid = False + self.p = p + self.q = q + self.num_classes = num_classes + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + self.cls_criterion = seesaw_ce_loss + + # cumulative samples for each category + self.register_buffer('cum_samples', + torch.zeros(self.num_classes, dtype=torch.float)) + + def forward(self, + cls_score, + labels, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + cls_score (torch.Tensor): The prediction with shape (N, C). + labels (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + Returns: + torch.Tensor: The calculated loss + """ + assert reduction_override in (None, 'none', 'mean', 'sum'), \ + f'The `reduction_override` should be one of (None, "none", ' \ + f'"mean", "sum"), but get "{reduction_override}".' + assert cls_score.size(0) == labels.view(-1).size(0), \ + f'Expected `labels` shape [{cls_score.size(0)}], ' \ + f'but got {list(labels.size())}' + reduction = ( + reduction_override if reduction_override else self.reduction) + assert cls_score.size(-1) == self.num_classes, \ + f'The channel number of output ({cls_score.size(-1)}) does ' \ + f'not match the `num_classes` of seesaw loss ({self.num_classes}).' + + # accumulate the samples for each category + unique_labels = labels.unique() + for u_l in unique_labels: + inds_ = labels == u_l.item() + self.cum_samples[u_l] += inds_.sum() + + if weight is not None: + weight = weight.float() + else: + weight = labels.new_ones(labels.size(), dtype=torch.float) + + # calculate loss_cls_classes + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, labels, weight, self.cum_samples, self.num_classes, + self.p, self.q, self.eps, reduction, avg_factor) + + return loss_cls diff --git a/mmpretrain/models/losses/swav_loss.py b/mmpretrain/models/losses/swav_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c7dbb78e9bf6619cede65a874569072b863bdfa0 --- /dev/null +++ b/mmpretrain/models/losses/swav_loss.py @@ -0,0 +1,190 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from mmengine.dist import all_reduce +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@torch.no_grad() +def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int, + world_size: int, epsilon: float) -> torch.Tensor: + """Apply the distributed sinknorn optimization on the scores matrix to find + the assignments. + + This function is modified from + https://github.com/facebookresearch/swav/blob/main/main_swav.py + + Args: + out (torch.Tensor): The scores matrix + sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp + algorithm. + world_size (int): The world size of the process group. + epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm. + + Returns: + torch.Tensor: Output of sinkhorn algorithm. + """ + eps_num_stab = 1e-12 + Q = torch.exp(out / epsilon).t( + ) # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + all_reduce(sum_Q) + Q /= sum_Q + + for it in range(sinkhorn_iterations): + # normalize each row: total weight per prototype must be 1/K + u = torch.sum(Q, dim=1, keepdim=True) + if len(torch.nonzero(u == 0)) > 0: + Q += eps_num_stab + u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype) + all_reduce(u) + Q /= u + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + +class MultiPrototypes(BaseModule): + """Multi-prototypes for SwAV head. + + Args: + output_dim (int): The output dim from SwAV neck. + num_prototypes (List[int]): The number of prototypes needed. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + output_dim: int, + num_prototypes: List[int], + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(num_prototypes, list) + self.num_heads = len(num_prototypes) + for i, k in enumerate(num_prototypes): + self.add_module('prototypes' + str(i), + nn.Linear(output_dim, k, bias=False)) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Run forward for every prototype.""" + out = [] + for i in range(self.num_heads): + out.append(getattr(self, 'prototypes' + str(i))(x)) + return out + + +@MODELS.register_module() +class SwAVLoss(BaseModule): + """The Loss for SwAV. + + This Loss contains clustering and sinkhorn algorithms to compute Q codes. + Part of the code is borrowed from `script + `_. + The queue is built in `engine/hooks/swav_hook.py`. + + Args: + feat_dim (int): feature dimension of the prototypes. + sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp + algorithm. Defaults to 3. + epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm. + Defaults to 0.05. + temperature (float): temperature parameter in training loss. + Defaults to 0.1. + crops_for_assign (List[int]): list of crops id used for computing + assignments. Defaults to [0, 1]. + num_crops (List[int]): list of number of crops. Defaults to [2]. + num_prototypes (int): number of prototypes. Defaults to 3000. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + feat_dim: int, + sinkhorn_iterations: int = 3, + epsilon: float = 0.05, + temperature: float = 0.1, + crops_for_assign: List[int] = [0, 1], + num_crops: List[int] = [2], + num_prototypes: int = 3000, + init_cfg: Optional[Union[List[dict], dict]] = None): + super().__init__(init_cfg=init_cfg) + self.sinkhorn_iterations = sinkhorn_iterations + self.epsilon = epsilon + self.temperature = temperature + self.crops_for_assign = crops_for_assign + self.num_crops = num_crops + self.use_queue = False + self.queue = None + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # prototype layer + self.prototypes = None + if isinstance(num_prototypes, list): + self.prototypes = MultiPrototypes(feat_dim, num_prototypes) + elif num_prototypes > 0: + self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False) + assert self.prototypes is not None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function of SwAV loss. + + Args: + x (torch.Tensor): NxC input features. + Returns: + torch.Tensor: The returned loss. + """ + # normalize the prototypes + with torch.no_grad(): + w = self.prototypes.weight.data.clone() + w = nn.functional.normalize(w, dim=1, p=2) + self.prototypes.weight.copy_(w) + + embedding, output = x, self.prototypes(x) + embedding = embedding.detach() + + bs = int(embedding.size(0) / sum(self.num_crops)) + loss = 0 + for i, crop_id in enumerate(self.crops_for_assign): + with torch.no_grad(): + out = output[bs * crop_id:bs * (crop_id + 1)].detach() + # time to use the queue + if self.queue is not None: + if self.use_queue or not torch.all(self.queue[i, + -1, :] == 0): + self.use_queue = True + out = torch.cat( + (torch.mm(self.queue[i], + self.prototypes.weight.t()), out)) + # fill the queue + self.queue[i, bs:] = self.queue[i, :-bs].clone() + self.queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) * + bs] + + # get assignments (batch_size * num_prototypes) + q = distributed_sinkhorn(out, self.sinkhorn_iterations, + self.world_size, self.epsilon)[-bs:] + + # cluster assignment prediction + subloss = 0 + for v in np.delete(np.arange(np.sum(self.num_crops)), crop_id): + x = output[bs * v:bs * (v + 1)] / self.temperature + subloss -= torch.mean( + torch.sum(q * nn.functional.log_softmax(x, dim=1), dim=1)) + loss += subloss / (np.sum(self.num_crops) - 1) + loss /= len(self.crops_for_assign) + return loss diff --git a/mmpretrain/models/losses/utils.py b/mmpretrain/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a65b68a6590aa3fe10a023022c9c9c9bce51f935 --- /dev/null +++ b/mmpretrain/models/losses/utils.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch +import torch.nn.functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + ``loss_func(pred, target, **kwargs)``. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like ``loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)``. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper + + +def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: + """This function converts target class indices to one-hot vectors, given + the number of classes. + + Args: + targets (Tensor): The ground truth label of the prediction + with shape (N, 1) + classes (int): the number of classes. + + Returns: + Tensor: Processed loss values. + """ + assert (torch.max(targets).item() < + classes), 'Class Index must be less than number of classes' + one_hot_targets = F.one_hot( + targets.long().squeeze(-1), num_classes=classes) + return one_hot_targets diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..072c0f84f723e1ce5f7b3efbda774e6c80f74063 --- /dev/null +++ b/mmpretrain/models/multimodal/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpretrain.utils.dependency import WITH_MULTIMODAL + +if WITH_MULTIMODAL: + from .blip import * # noqa: F401,F403 + from .blip2 import * # noqa: F401,F403 + from .chinese_clip import * # noqa: F401, F403 + from .flamingo import * # noqa: F401, F403 + from .llava import * # noqa: F401, F403 + from .minigpt4 import * # noqa: F401, F403 + from .ofa import * # noqa: F401, F403 + from .otter import * # noqa: F401, F403 +else: + from mmpretrain.registry import MODELS + from mmpretrain.utils.dependency import register_multimodal_placeholder + + register_multimodal_placeholder([ + 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', + 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', + 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter' + ], MODELS) diff --git a/mmpretrain/models/multimodal/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e431010711dfafa21c6faae7810c5b2a8d2138cc Binary files /dev/null and b/mmpretrain/models/multimodal/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip/__init__.py b/mmpretrain/models/multimodal/blip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbc0da6e0d11c116d4575b6c981724e387e415a --- /dev/null +++ b/mmpretrain/models/multimodal/blip/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .blip_caption import BlipCaption +from .blip_grounding import BlipGrounding +from .blip_nlvr import BlipNLVR +from .blip_retrieval import BlipRetrieval +from .blip_vqa import BlipVQA +from .language_model import BertLMHeadModel, XBertEncoder, XBertLMHeadDecoder + +__all__ = [ + 'BertLMHeadModel', 'BlipCaption', 'BlipGrounding', 'BlipNLVR', + 'BlipRetrieval', 'BlipVQA', 'XBertEncoder', 'XBertLMHeadDecoder' +] diff --git a/mmpretrain/models/multimodal/blip/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/blip/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d848e7961886f35fc92bd175a1a3f98098725e9a Binary files /dev/null and b/mmpretrain/models/multimodal/blip/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip/__pycache__/blip_caption.cpython-38.pyc b/mmpretrain/models/multimodal/blip/__pycache__/blip_caption.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51812933b0700a6ae6dbd1906bfaf46618fc00b1 Binary files /dev/null and b/mmpretrain/models/multimodal/blip/__pycache__/blip_caption.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip/__pycache__/blip_grounding.cpython-38.pyc b/mmpretrain/models/multimodal/blip/__pycache__/blip_grounding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2317a3732deee6fbc4835c7f616fc57ec4dc013e Binary files /dev/null and b/mmpretrain/models/multimodal/blip/__pycache__/blip_grounding.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip/__pycache__/blip_nlvr.cpython-38.pyc b/mmpretrain/models/multimodal/blip/__pycache__/blip_nlvr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d59f8c7a7c18295d93984af81aa18ec9e7b6addb Binary files /dev/null and b/mmpretrain/models/multimodal/blip/__pycache__/blip_nlvr.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip/__pycache__/blip_retrieval.cpython-38.pyc b/mmpretrain/models/multimodal/blip/__pycache__/blip_retrieval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f041668d4dac102518510abcf0f339efbacdeef5 Binary files /dev/null and b/mmpretrain/models/multimodal/blip/__pycache__/blip_retrieval.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip/__pycache__/blip_vqa.cpython-38.pyc b/mmpretrain/models/multimodal/blip/__pycache__/blip_vqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf71bda8c732454071f8dd28fac37a164f40a5d8 Binary files /dev/null and b/mmpretrain/models/multimodal/blip/__pycache__/blip_vqa.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip/__pycache__/language_model.cpython-38.pyc b/mmpretrain/models/multimodal/blip/__pycache__/language_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3136bac9c1dac76ec3c14a5ea22383067f7669b5 Binary files /dev/null and b/mmpretrain/models/multimodal/blip/__pycache__/language_model.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip/blip_caption.py b/mmpretrain/models/multimodal/blip/blip_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..9af3e2408da8c6b3a55694a1323e6434dfc609e1 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_caption.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class BlipCaption(BaseModel): + """BLIP Caption. + + Args: + vision_encoder (dict): Encoder for extracting image features. + decoder_head (dict): The decoder head module to forward and + calculate loss from processed features. + tokenizer: (Optional[dict]): The config for tokenizer. + Defaults to None. + prompt (str): Prompt used for training and eval. + Defaults to ''. + max_txt_len (int): Max text length of input text. + num_captions (int): Number of captions to be generated for each image. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_encoder: dict, + decoder_head: dict, + tokenizer: Optional[dict] = None, + prompt: str = '', + max_txt_len: int = 20, + num_captions: int = 1, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(BlipCaption, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.visual_encoder = MODELS.build(vision_encoder) + self.seq_gen_head = MODELS.build(decoder_head) + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + self.max_txt_len = max_txt_len + self.num_captions = num_captions + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + The method should accept two modes: "predict" and "loss": + + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (torch.Tensor): pre_processed img tensor (N, C, ...). + data_samples (List[DataSample], optional): Data samples with + additional infos. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, images, data_samples=None, **kwargs): + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input images tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # prepare inputs for decoder generation. + image_embeds = self.visual_encoder(images)[0] + image_embeds = torch.repeat_interleave(image_embeds, self.num_captions, + 0) + + prompt = [self.prompt] * image_embeds.size(0) + prompt = self.tokenizer( + prompt, padding='longest', + return_tensors='pt').to(image_embeds.device) + + prompt.input_ids[:, 0] = self.tokenizer.bos_token_id + prompt.input_ids = prompt.input_ids[:, :-1] + + decoder_out = self.seq_gen_head.predict( + input_ids=prompt.input_ids, + encoder_hidden_states=image_embeds, + sep_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + output_attentions=True, + return_dict_in_generate=True, + ) + + decode_tokens = self.tokenizer.batch_decode( + decoder_out.sequences, skip_special_tokens=True) + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(len(decode_tokens))] + + for data_sample, decode_token in zip(data_samples, decode_tokens): + if data_sample is None: + data_sample = DataSample() + data_sample.pred_caption = decode_token[len(self.prompt):] + out_data_samples.append(data_sample) + + return out_data_samples + + def loss(self, images, data_samples): + """Calculate losses from a batch of images and data samples. + + Args: + images (torch.Tensor): The input images tensor with shape + (N, C, ...) in general. + data_samples (List[ImageTextDataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + image_embeds = self.visual_encoder(images)[0] + raw_text = [self.prompt + ds.gt_caption for ds in data_samples] + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(image_embeds.device) + text.input_ids[:, 0] = self.tokenizer.bos_token_id + + # prepare targets for forwarding decoder + labels = text.input_ids.masked_fill( + text.input_ids == self.tokenizer.pad_token_id, -100) + labels[:, :self.prompt_length] = -100 + # forward decoder + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) + + losses = self.seq_gen_head.loss( + input_ids=text.input_ids, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + labels=labels, + ) + return losses diff --git a/mmpretrain/models/multimodal/blip/blip_grounding.py b/mmpretrain/models/multimodal/blip/blip_grounding.py new file mode 100644 index 0000000000000000000000000000000000000000..cb087287220a91b3bfcd50acee244eb5dc118bac --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_grounding.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.model import BaseModel + +from mmpretrain.models.utils.box_utils import box_xyxy_to_cxcywh +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures.data_sample import DataSample + + +@MODELS.register_module() +class BlipGrounding(BaseModel): + """BLIP Grounding. + + Args: + visual_encoder (dict): Backbone for extracting image features. + text_encoder (dict): Backbone for extracting text features. + but we integrate the vqa text extractor + into the tokenizer part in datasets/transform/ + so we don't need text_backbone + multimodal_encoder (Optional[dict]): Backbone for extracting + multi-modal features. We apply this part as VQA fusion module. + neck (Optional[dict]): The neck module to process features from + backbone. Defaults to None. + head (Optional[Union[List[dict], dict]]): The head module to calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + tokenizer: Optional[dict] = None, + visual_encoder: Optional[dict] = None, + text_encoder: Optional[dict] = None, + multimodal_encoder: Optional[dict] = None, + head: Optional[Union[List[dict], dict]] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(BlipGrounding, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.prompt = 'localize instance: ' + self.visual_encoder = MODELS.build(visual_encoder) + self.text_encoder = MODELS.build(text_encoder) + self.multimodal_encoder = MODELS.build(multimodal_encoder) + head.setdefault('tokenizer', self.tokenizer) + self.grounding_head = MODELS.build(head) + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + The method should accept only one mode "loss": + + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor, tuple): The input tensor with shape + (N, C, ...) in general. + data_samples (List[VQADataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, images: torch.Tensor) -> torch.Tensor: + """Extract features from the input tensor with shape (N, C, ...). + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + Returns: + image_embeds (Tensor): The output features. + """ + image_embeds = self.visual_encoder(images)[0] + return image_embeds + + def loss( + self, + images: torch.Tensor, + data_samples=None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """generate train_loss from the input tensor and data_samples. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + data_samples (List[VQADataSample], optional): The annotation + data of every samples.. + + Returns: + Dict[torch.Tensor]: The losses features. + """ + + # extract image feature + image_embeds = self.extract_feat(images) + image_atts = image_embeds.new_ones( + image_embeds.size()[:-1], dtype=torch.long) + + raw_text = [] + box_targets = [] + for ds in data_samples: + + raw_text.append(ds.text) + box_t = copy.deepcopy(ds.box) * 1.0 + box_t[1] /= ds.img_shape[0] + box_t[3] /= ds.img_shape[0] + box_t[0] /= ds.img_shape[1] + box_t[2] /= ds.img_shape[1] + + box_targets.append(box_t) + + box_targets = image_embeds.new_tensor(np.stack(box_targets)) + box_targets = box_xyxy_to_cxcywh(box_targets) # xywh 0-1 + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=128, + return_tensors='pt', + ).to(image_embeds.device) + + text_embeds = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + mode='text', + return_dict=True) # bz, seq_len, hid + + # multimodal fusion + multimodal_embeds = self.multimodal_encoder( + encoder_embeds=text_embeds.last_hidden_state, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + losses = self.grounding_head.loss( + text_embedding=multimodal_embeds.last_hidden_state, + text_embedding_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + decoder_targets=box_targets, + ) + + return losses + + def predict(self, images, data_samples=None): + """""" + + # extract image feature + image_embeds = self.extract_feat(images) + image_atts = image_embeds.new_ones( + image_embeds.size()[:-1], dtype=torch.long) + + raw_text = [] + for ds in data_samples: + raw_text.append(ds.text) + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=128, + return_tensors='pt', + ).to(image_embeds.device) + + text_embeds = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + mode='text', + return_dict=True) # bz, seq_len, hid + + # multimodal fusion + multimodal_embeds = self.multimodal_encoder( + encoder_embeds=text_embeds.last_hidden_state, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + output_boxes = self.grounding_head.predict( + text_embedding=multimodal_embeds.last_hidden_state, + text_embedding_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + ) # xyxy 0-1 + + out_data_samples = [] + for bbox, data_sample, img in zip(output_boxes, data_samples, images): + if data_sample is None: + data_sample = DataSample() + + img_size = img.shape[-2:] + scale_factor = data_sample.get('scale_factor', (1, 1)) + bbox[0::2] = bbox[0::2] * img_size[1] / scale_factor[0] + bbox[1::2] = bbox[1::2] * img_size[0] / scale_factor[1] + bbox = bbox[None, :] + data_sample.pred_bboxes = bbox + + if 'gt_bboxes' in data_sample: + gt_bboxes = torch.Tensor(data_sample.get('gt_bboxes')) + gt_bboxes[:, 0::2] /= scale_factor[0] + gt_bboxes[:, 1::2] /= scale_factor[1] + data_sample.gt_bboxes = gt_bboxes + + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/blip/blip_nlvr.py b/mmpretrain/models/multimodal/blip/blip_nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..f96e3cce237fd3b064c74264e8f907a8bd3a47ca --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_nlvr.py @@ -0,0 +1,205 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER + + +@MODELS.register_module() +class BlipNLVR(BaseModel): + """BLIP NLVR. + + Args: + vision_backbone (dict): Backbone for extracting image features. + text_backbone (dict): Backbone for extracting text features. + but we integrate the vqa text extractor into the tokenizer part in + datasets/transform/ so we don't need text_backbone + multimodal_backbone (Optional[dict]): Backbone for extracting + multi-modal features. We apply this part as VQA fusion module. + neck (Optional[dict]): The neck module to process features from + backbone. Defaults to None. + head (Optional[dict]): The head module to calculate + loss from processed features. See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + tokenizer: (Optional[dict]): The config for tokenizer + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + multimodal_backbone: dict, + tokenizer: Optional[dict] = None, + max_txt_len: int = 35, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + if tokenizer is not None: + self.tokenizer = TOKENIZER.build(tokenizer) + self.vision_backbone = MODELS.build(vision_backbone) + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.max_txt_len = max_txt_len + + # For simplity, directly use head definition here. + # If more complex head is designed, move this and loss to a new + # head module. + hidden_size = self.multimodal_backbone.config.hidden_size + self.head = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, 2), + ) + + @property + def device(self): + return next(self.parameters()).device + + def preprocess_text(self, data_samples): + + sample_item = data_samples[0] + + if sample_item is not None and 'text' in sample_item: + texts = [sample.get('text') for sample in data_samples] + else: + return None + + # perform tokenize first if satisfied conditions + texts = self.tokenizer( + texts, + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(self.device) + + return texts + + def forward( + self, + images: dict, + data_samples: Optional[List] = None, + mode: str = 'tensor', + ): + """The unified entry for a forward process in both training and test. + The method should accept only one mode "loss": + + - "loss": Forward and return a dict of losses according to the given + images and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (dict of torch.Tensor): + img: pre_processed img tensor (N, C, ...). + text: tokenized text (N, L) + data_samples (List[CaptionDataSample], optional): + The annotation data of every samples. + 'image': raw image data + 'text' tokenized text + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + # B, T, C, H, W to T*B, C, H, W + images = images.permute(1, 0, 2, 3, 4).flatten(0, 1) + + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, images, data_samples=None): + """Predict caption.""" + # prepare inputs for decoder generation. + image_embeds = self.vision_backbone(images)[0] + texts = self.preprocess_text(data_samples) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + image0_embeds, image1_embeds = torch.split(image_embeds, + texts.input_ids.size(0)) + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + texts.input_ids, + attention_mask=texts.attention_mask, + encoder_hidden_states=[image0_embeds, image1_embeds], + encoder_attention_mask=[ + image_atts[:image0_embeds.size(0)], + image_atts[image0_embeds.size(0):], + ], + return_dict=True, + ) + + # get prediction + outputs = self.head(multimodal_embeds.last_hidden_state[:, 0, :]) + + pred_scores = F.softmax(outputs, dim=1) + + for pred_score, data_sample in zip(pred_scores, data_samples): + data_sample.set_pred_score(pred_score) + data_sample.set_pred_label(pred_score.argmax(dim=0)) + + return data_samples + + def loss(self, images, data_samples): + """Calculate losses from a batch of inputs and data samples. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[ImageTextDataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # prepare inputs for decoder generation. + image_embeds = self.vision_backbone(images)[0] + texts = self.preprocess_text(data_samples) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + image0_embeds, image1_embeds = torch.split(image_embeds, + texts.input_ids.size(0)) + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + texts.input_ids, + attention_mask=texts.attention_mask, + encoder_hidden_states=[image0_embeds, image1_embeds], + encoder_attention_mask=[ + image_atts[:image0_embeds.size(0)], + image_atts[image0_embeds.size(0):], + ], + return_dict=True, + ) + + # get prediction + outputs = self.head(multimodal_embeds.last_hidden_state[:, 0, :]) + + targets = torch.tensor([i.gt_label + for i in data_samples]).to(outputs.device) + loss = F.cross_entropy(outputs, targets) + return {'loss': loss} diff --git a/mmpretrain/models/multimodal/blip/blip_retrieval.py b/mmpretrain/models/multimodal/blip/blip_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..8983e63e20832fa2e9b36e39134b6fe748baab61 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_retrieval.py @@ -0,0 +1,716 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import ChainMap +from copy import deepcopy +from typing import Dict, List, Optional, Tuple, Union + +import mmengine.dist as dist +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModel +from torch import distributed as torch_dist + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process + + +def all_gather_concat(data: torch.Tensor) -> torch.Tensor: + """Gather tensors with different first-dimension size and concat to one + tenosr. + + Note: + Only the first dimension should be different. + + Args: + data (Tensor): Tensor to be gathered. + + Returns: + torch.Tensor: The concatenated tenosr. + """ + if dist.get_world_size() == 1: + return data + + data_size = torch.tensor(data.size(0), device=data.device) + sizes_list = dist.all_gather(data_size) + + max_length = max(sizes_list) + size_diff = max_length.item() - data_size.item() + if size_diff: + padding = torch.zeros( + size_diff, *data.size()[1:], device=data.device, dtype=data.dtype) + data = torch.cat((data, padding)) + + gather_list = dist.all_gather(data) + + all_data = [] + for tensor, size in zip(gather_list, sizes_list): + + all_data.append(tensor[:size]) + + return torch.concat(all_data) + + +@MODELS.register_module() +class BlipRetrieval(BaseModel): + """BLIP Retriever. + + Args: + vision_backbone (dict): Backbone for extracting image features. + text_backbone (dict): Backbone for extracting text features. + multimodal_backbone (Optional[dict]): Backbone for extracting + multi-modal features. + vision_neck (Optional[dict]): The neck module to process image features + from vision backbone. Defaults to None. + text_neck (Optional[dict]): The neck module to process text features + from text backbone. Defaults to None. + head (Optional[Union[List[dict], dict]]): The head module to calculate + loss from processed single modality features. + See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + multimodal_head (Optional[Union[List[dict], dict]]): The multi-modal + head module to calculate loss from processed multimodal features. + See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + momentum (float): Momentum used for momentum contrast. + Defaults to .995. + negative_all_rank (bool): Whether to sample negative data from all + ranks for image text matching in training. Defaults to True. + temperature (float): Temperature parameter that controls the + concentration level of the distribution. Defaults to 0.07. + fast_match (bool): If False, select topk similarity as candidates and + compute the matching score. If True, return the similarity as the + matching score directly. Defaults to False. + topk (int): Select topk similarity as candidates for compute matching + scores. Notice that this is not the topk in evaluation. + Defaults to 256. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + multimodal_backbone: Optional[dict] = None, + vision_neck: Optional[dict] = None, + text_neck: Optional[dict] = None, + head: Optional[Union[List[dict], dict]] = None, + multimodal_head: Optional[Union[List[dict], dict]] = None, + tokenizer: Optional[dict] = None, + momentum: float = .995, + negative_all_rank: bool = True, + temperature: float = 0.07, + fast_match: bool = False, + topk: int = 256, + max_txt_len: int = 20, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.vision_backbone = MODELS.build(vision_backbone) + self.text_backbone = MODELS.build(text_backbone) + + if multimodal_backbone is not None: + self.multimodal_backbone = MODELS.build(multimodal_backbone) + + if vision_neck is not None: + self.vision_neck = MODELS.build(vision_neck) + + if text_neck is not None: + self.text_neck = MODELS.build(text_neck) + + if head is not None: + self.head = MODELS.build(head) + + if multimodal_head is not None: + self.multimodal_head = MODELS.build(multimodal_head) + + if tokenizer is not None: + self.tokenizer = TOKENIZER.build(tokenizer) + + self.momentum = momentum + self.negative_all_rank = negative_all_rank + self.temp = nn.Parameter(temperature * torch.ones([])) + # Shares the same para + self.head.temp = self.temp + + # create the momentum encoder + self.vision_backbone_m = deepcopy(self.vision_backbone) + self.text_backbone_m = deepcopy(self.text_backbone) + + self.vision_neck_m = deepcopy(self.vision_neck) + self.text_neck_m = deepcopy(self.text_neck) + + self.model_pairs = [ + [self.vision_backbone, self.vision_backbone_m], + [self.text_backbone, self.text_backbone_m], + [self.vision_neck, self.vision_neck_m], + [self.text_neck, self.text_neck_m], + ] + self.copy_params() + + # multimodal backone shares weights with text backbone in BLIP + # No need to set up + + # Notice that this topk is used for select k candidate to compute + # image-text score, but not the final metric topk in evaluation. + self.fast_match = fast_match + self.topk = topk + + self.max_txt_len = max_txt_len + + @property + def device(self): + return next(self.parameters()).device + + def preprocess_text(self, data_samples): + sample_item = data_samples[0] + + if sample_item is not None and 'text' in sample_item: + if isinstance(sample_item.get('text'), (list, tuple)): + texts = [] + for sample in data_samples: + texts.extend(sample.get('text')) + elif isinstance(sample_item.get('text'), str): + texts = [sample.get('text') for sample in data_samples] + else: + raise TypeError('text must be a string or a list of strings') + else: + return None + + # perform tokenize first if satisfied conditions + texts = self.tokenizer( + texts, + padding='max_length', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(self.device) + + return texts + + def forward(self, + images: torch.tensor = None, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor') -> Union[Tuple, dict]: + """The unified entry for a forward process in both training and test. + The method should accept two modes: "tensor", and "loss": + + - "tensor": Forward the whole network and return tensor without any + post-processing, same as a common nn.Module. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + For unified "predict" mode in other mm repos. It is noticed that + image-text retrieval cannot perform batch prediction since it will go + through all the samples. A standard process of retrieval evaluation is + to extract and collect all feats, and then predict all samples. + Therefore the `predict` mode here is remained as a trigger + to inform use to choose the right configurations. + + Args: + images (torch.Tensor): The input inputs tensor of shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + - If ``mode="tensor"``, return a tuple. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + return self.extract_feat(images, data_samples) + elif mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat( + self, + images: torch.Tensor = None, + data_samples: List[DataSample] = None, + return_texts=True, + return_embeds=None, + ) -> Dict[str, torch.Tensor]: + """Extract features from the input dict. + + Args: + images (tensor, optional): The images to extract features. + Defaults to None. + data_samples (list, optional): The data samples containing texts + to extract features. Defaults to None. + return_texts (bool): Whether to return the tokenized text and the + corresponding attention masks. Defaults to True. + return_embeds (bool): Whether to return the text embedding and + image embedding. Defaults to None, which means to use + ``self.fast_match``. + + Returns: + Tuple[torch.Tensor]: The output features. + If multimodal_backbone is not exist, tuple of torch.Tensor + will be returned. + """ + if data_samples is not None: + texts = self.preprocess_text(data_samples) + else: + texts = None + + assert images is not None or texts is not None, \ + 'At least single modality should be passed as inputs.' + + results = {} + if texts is not None and return_texts: + results.update({ + 'text_ids': texts.input_ids, + 'text_attn_mask': texts.attention_mask, + }) + + if return_embeds is None: + return_embeds = not self.fast_match + + # extract image features + if images is not None: + output = self._extract_feat(images, modality='images') + results['image_feat'] = output['image_feat'] + if return_embeds: + results['image_embeds'] = output['image_embeds'] + + # extract text features + if texts is not None: + output = self._extract_feat(texts, modality='texts') + results['text_feat'] = output['text_feat'] + if return_embeds: + results['text_embeds'] = output['text_embeds'] + + return results + + def _extract_feat(self, inputs: Union[torch.Tensor, dict], + modality: str) -> Tuple[torch.Tensor]: + """Extract features from the single modality. + + Args: + inputs (Union[torch.Tensor, dict]): A batch of inputs. + For image, a tensor of shape (N, C, ...) in general. + For text, a dict of tokenized text inputs. + modality (str): Modality feature to be extracted. Only two + options are supported. + + - ``images``: Only extract image features, mostly used for + inference. + - ``texts``: Only extract text features, mostly used for + inference. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + + if modality == 'images': + # extract image features + image_embeds = self.vision_backbone(inputs)[0] + image_feat = F.normalize( + self.vision_neck(image_embeds[:, 0, :]), dim=-1) + return {'image_embeds': image_embeds, 'image_feat': image_feat} + elif modality == 'texts': + # extract text features + text_output = self.text_backbone( + inputs.input_ids, + attention_mask=inputs.attention_mask, + token_type_ids=None, + return_dict=True, + mode='text', + ) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize( + self.text_neck(text_embeds[:, 0, :]), dim=-1) + return {'text_embeds': text_embeds, 'text_feat': text_feat} + else: + raise RuntimeError(f'Invalid modality "{modality}".') + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Dict[str, torch.tensor]: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (dict): A batch of inputs. The input tensor with of + at least one modality. For image, the value is a tensor + of shape (N, C, ...) in general. + For text, the value is a dict of tokenized text inputs. + data_samples (Optional[List[DataSample]]): + The annotation data of every samples. Defaults to None. + + Returns: + Dict[str, torch.tensor]: a dictionary of loss components of + both head and multimodal head. + """ + output = self.extract_feat(images, data_samples, return_embeds=True) + + text_ids = output['text_ids'] + text_attn_mask = output['text_attn_mask'] + image_embeds = output['image_embeds'] + image_feat = output['image_feat'] + text_feat = output['text_feat'] + + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.vision_backbone_m(images)[0] + image_feat_m = F.normalize( + self.vision_neck_m(image_embeds_m[:, 0, :]), dim=-1) + + text_output_m = self.text_backbone_m( + text_ids, + attention_mask=text_attn_mask, + token_type_ids=None, + return_dict=True, + mode='text', + ) + text_embeds_m = text_output_m.last_hidden_state + text_feat_m = F.normalize( + self.text_neck_m(text_embeds_m[:, 0, :]), dim=-1) + + loss = self.head.loss( + ([image_feat, text_feat, image_feat_m, text_feat_m], ), + data_samples) + + # prepare for itm + encoder_input_ids = text_ids.clone() + encoder_input_ids[:, + 0] = self.tokenizer.additional_special_tokens_ids[0] + output_pos = self.text_backbone( + encoder_input_ids, + attention_mask=text_attn_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + idx = torch.tensor([i.image_id for i in data_samples]).view(-1, 1) + bs = idx.size(0) + idxs = torch.cat(dist.all_gather(idx)) + if self.negative_all_rank: + # compute sample similarity + with torch.no_grad(): + mask = torch.eq(idx, idxs.t()).to(self.device) + + image_feat_world = torch.cat(dist.all_gather(image_feat)) + text_feat_world = torch.cat(dist.all_gather(text_feat)) + + sim_i2t = image_feat @ text_feat_world.t() / self.temp + sim_t2i = text_feat @ image_feat_world.t() / self.temp + + weights_i2t = F.softmax(sim_i2t, dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = F.softmax(sim_t2i, dim=1) + weights_t2i.masked_fill_(mask, 0) + + world_size = dist.get_world_size() + if world_size == 1: + image_embeds_world = image_embeds + else: + image_embeds_world = torch.cat( + torch_dist.nn.all_gather(image_embeds)) + + # select a negative image (from all ranks) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text (from all ranks) for each image + input_ids_world = torch.cat(dist.all_gather(encoder_input_ids)) + att_mask_world = torch.cat(dist.all_gather(text_attn_mask)) + + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(input_ids_world[neg_idx]) + text_atts_neg.append(att_mask_world[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0) + text_atts_all = torch.cat([text_attn_mask, text_atts_neg], dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + output_neg = self.text_backbone( + text_ids_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = torch.cat( + [ + output_pos.last_hidden_state[:, 0, :], + output_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + + # create false data samples + data_samples.extend( + [DataSample(is_matched=False) for _ in range(2 * bs)]) + loss_multimodal = self.multimodal_head.loss((vl_embeddings, ), + data_samples) + + return dict(ChainMap(loss, loss_multimodal)) + + def predict(self, images, data_samples, cal_i2t=True, cal_t2i=True): + feats = self.extract_feat(images, data_samples) + + return self.predict_all( + feats, data_samples, cal_i2t=cal_i2t, cal_t2i=cal_t2i) + + def predict_all(self, + feats, + data_samples, + num_images=None, + num_texts=None, + cal_i2t=True, + cal_t2i=True): + text_ids = feats['text_ids'] + text_ids[:, 0] = self.tokenizer.additional_special_tokens_ids[0] + text_attn_mask = feats['text_attn_mask'] + image_embeds = feats.get('image_embeds', None) + image_feat = feats['image_feat'] + text_feat = feats['text_feat'] + + num_images = num_images or image_feat.size(0) + num_texts = num_texts or text_feat.size(0) + + if not self.fast_match: + image_embeds_all = all_gather_concat(image_embeds)[:num_images] + else: + image_embeds_all = None + image_feat_all = all_gather_concat(image_feat)[:num_images] + text_feat_all = all_gather_concat(text_feat)[:num_texts] + text_ids_all = all_gather_concat(text_ids)[:num_texts] + text_attn_mask_all = all_gather_concat(text_attn_mask)[:num_texts] + + results = [] + if cal_i2t: + result_i2t = self.compute_score_matrix_i2t( + image_feat, + image_embeds, + text_feat_all, + text_ids_all, + text_attn_mask_all, + ) + results.append( + self._get_predictions(result_i2t, data_samples, mode='i2t')) + if cal_t2i: + result_t2i = self.compute_score_matrix_t2i( + image_feat_all, + image_embeds_all, + text_feat, + text_ids, + text_attn_mask, + ) + results.append( + self._get_predictions(result_t2i, data_samples, mode='t2i')) + return tuple(results) + + def compute_score_matrix_i2t(self, img_feats, img_embeds, text_feats, + text_ids, text_atts): + """Compare the score matrix for image-to-text retrieval. Every image + should compare to all the text features. + + Args: + img_feats (torch.Tensor): The input img feats tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + img_embeds (torch.Tensor): The input img embeds tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + text_feats (torch.Tensor): The input text feats tensor with shape + (N, C). N stands for numbers of all samples on all GPUs. + text_ids (torch.Tensor): The input tensor with shape (N, C). + text_atts (torch.Tensor): The input tensor with shape (N, C). + + Returns: + torch.Tensor: Score matrix of image-to-text retrieval. + """ + + # compute i2t sim matrix + sim_matrix_i2t = img_feats @ text_feats.t() + if self.fast_match: + return sim_matrix_i2t + + score_matrix_i2t = torch.full((img_feats.size(0), text_feats.size(0)), + -100.0).to(self.device) + for i in track_on_main_process( + range(img_feats.size(0)), 'Compute I2T scores...'): + sims = sim_matrix_i2t[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + + encoder_output = img_embeds[i].repeat(self.topk, 1, 1) + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + output = self.text_backbone( + text_ids[topk_idx], + attention_mask=text_atts[topk_idx], + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, 0, :], ))[:, 1] + score_matrix_i2t[i, topk_idx] = score + topk_sim + + return score_matrix_i2t + + def compute_score_matrix_t2i(self, img_feats, img_embeds, text_feats, + text_ids, text_atts): + """Compare the score matrix for text-to-image retrieval. Every text + should compare to all the image features. + + Args: + img_feats (torch.Tensor): The input img feats tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + img_embeds (torch.Tensor): The input img embeds tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + text_feats (torch.Tensor): The input text feats tensor with shape + (N, C). N stands for numbers of all samples on all GPUs. + text_ids (torch.Tensor): The input tensor with shape (M, C). + text_atts (torch.Tensor): The input tensor with shape (M, C). + + Returns: + torch.Tensor: Score matrix of text-to-image retrieval. + """ + + # compute t2i sim matrix + sim_matrix_t2i = text_feats @ img_feats.t() + if self.fast_match: + return sim_matrix_t2i + + score_matrix_t2i = torch.full((text_feats.size(0), img_feats.size(0)), + -100.0).to(self.device) + for i in track_on_main_process( + range(text_feats.size(0)), 'Compute T2I scores...'): + sims = sim_matrix_t2i[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + + encoder_output = img_embeds[topk_idx] + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + output = self.text_backbone( + text_ids[i].repeat(self.topk, 1), + attention_mask=text_atts[i].repeat(self.topk, 1), + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, 0, :], ))[:, 1] + score_matrix_t2i[i, topk_idx] = score + topk_sim + + return score_matrix_t2i + + def _get_predictions(self, + result: torch.Tensor, + data_samples: List[DataSample], + mode: str = 'i2t'): + """Post-process the output of retriever. + + Args: + result (torch.Tensor): Score matrix of single retrieve, + either from image or text. + data_samples (List[DataSample], optional): The annotation + data of every samples. + mode (str): Retrieve mode, either `i2t` for image to text, or `t2i` + text to image. Defaults to `i2t`. + + Returns: + List[DataSample]: the raw data_samples with + the predicted results. + """ + + # create data sample if not exists + if data_samples is None: + data_samples = [DataSample() for _ in range(result.size(0))] + elif mode == 't2i': + # Process data samples to align with the num of texts. + new_data_samples = [] + for sample in data_samples: + if isinstance(sample.text, (list, tuple)): + texts = sample.text + else: + texts = [sample.text] + for i, text in enumerate(texts): + new_sample = DataSample(text=text) + if 'gt_image_id' in sample: + new_sample.gt_label = sample.gt_image_id[i] + new_data_samples.append(new_sample) + assert len(new_data_samples) == result.size(0) + data_samples = new_data_samples + elif mode == 'i2t': + for sample in data_samples: + if 'gt_text_id' in sample: + sample.gt_label = sample.gt_text_id + else: + raise ValueError(f'Type {mode} is not supported.') + + for data_sample, score in zip(data_samples, result): + idx = score.argmax(keepdim=True).detach() + + data_sample.set_pred_score(score) + data_sample.set_pred_label(idx) + return data_samples + + # TODO: add temperaily + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), + model_pair[1].parameters()): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for (name, + param), (name_m, + param_m) in zip(model_pair[0].named_parameters(), + model_pair[1].named_parameters()): + # hack to behave the same + if any([i in name for i in ['8', '9', '10', '11'] + ]) and 'layers' in name and any( + [i in name for i in ['attn', 'ffn']]): + param_m.data = param.data + else: + param_m.data = param_m.data * self.momentum + \ + param.data * (1.0 - self.momentum) diff --git a/mmpretrain/models/multimodal/blip/blip_vqa.py b/mmpretrain/models/multimodal/blip/blip_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f4e5861b5c92be302cc48eaa7a37264be63f93 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_vqa.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class BlipVQA(BaseModel): + """BLIP VQA. + + Args: + tokenizer: (dict): The config for tokenizer. + vision_backbone (dict): Encoder for extracting image features. + multimodal_backbone (dict): Backbone for extracting + multi-modal features. We apply this part as VQA fusion module. + head (dict): The head module to calculate + loss from processed features. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + `MutimodalDataPreprocessor` as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + multimodal_backbone: dict, + head: dict, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(BlipVQA, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.vision_backbone = MODELS.build(vision_backbone) + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.vqa_head = MODELS.build(head) + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + + - "loss": For training. Forward and return a dict of losses according + to the given inputs and data samples. Note that this method doesn't + handle neither back propagation nor optimizer updating, which are + done in the :meth:`train_step`. + - "predict": For testing. Forward and return a list of data_sample that + contains pred_answer for each question. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation data of + every samples. Required when ``mode="loss"``. Defaults to None. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + - If ``mode="predict"``, return a list of `DataSample` + """ + + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, images: torch.Tensor) -> torch.Tensor: + """Extract features from the input tensor with shape (N, C, ..). + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + + Returns: + visual_embeds (Tensor): The output features. + """ + # extract visual feature + if images.ndim == 4: + visual_embeds = self.vision_backbone(images)[0] + elif images.ndim == 5: + # [batch, T, C, H, W] -> [batch * T, C, H, W] + bs = images.size(0) + images = images.reshape(-1, *images.shape[2:]) + visual_embeds = self.vision_backbone(images)[0] + # [batch * num_segs, L, dim] -> [batch, num_segs * L, dim] + visual_embeds = visual_embeds.reshape(bs, -1, + *visual_embeds.shape[2:]) + else: + raise ValueError( + f'Images with {images.ndim} dims is not supported.') + return visual_embeds + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """generate train_loss from the input tensor and data_samples. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + Dict[torch.Tensor]: The losses features. + """ + visual_embeds = self.extract_feat(images) + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + questions = [] + for sample in data_samples: + questions.append(sample.get('question')) + questions = self.tokenizer( + questions, padding='longest', return_tensors='pt').to(self.device) + + questions.input_ids[:, 0] = \ + self.tokenizer.additional_special_tokens_ids[0] + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + questions.input_ids, + attention_mask=questions.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + answer_raw_text = [] + for sample in data_samples: + answer_raw_text.extend(sample.gt_answer) + answer = self.tokenizer( + answer_raw_text, padding='longest', + return_tensors='pt').to(self.device) + answer_targets = answer.input_ids.masked_fill( + answer.input_ids == self.tokenizer.pad_token_id, -100) + for sample in data_samples: + # follow BLIP setting, set answer_weight to 0.2 for VG dataset. + if not hasattr(sample, 'gt_answer_weight'): + sample.gt_answer_weight = torch.tensor([0.2]) + else: + sample.gt_answer_weight = torch.tensor(sample.gt_answer_weight) + answer_weight = torch.cat( + [sample.gt_answer_weight for sample in data_samples], + dim=0).to(self.device) + answer_count = torch.tensor( + [len(sample.gt_answer) for sample in data_samples]).to(self.device) + + question_states, question_atts = [], [] + for b, n in enumerate(answer_count): + question_states += [multimodal_embeds.last_hidden_state[b]] * n + question_atts += [questions.attention_mask[b]] * n + + question_states = torch.stack(question_states, dim=0).to(self.device) + question_atts = torch.stack(question_atts, dim=0).to(self.device) + + head_feats = dict( + answer_input_ids=answer.input_ids, + answer_attention_mask=answer.attention_mask, + answer_weight=answer_weight, + answer_targets=answer_targets, + question_states=question_states, + question_atts=question_atts, + batch_size=len(data_samples), + ) + + losses = self.vqa_head.loss(head_feats) + + return losses + + def predict( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ): + """update data_samples that contain pred_answer for each question. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + Dict[torch.Tensor]: The losses features. + """ + visual_embeds = self.extract_feat(images) + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + questions = [] + for sample in data_samples: + questions.append(sample.get('question')) + questions = self.tokenizer( + questions, padding='longest', return_tensors='pt').to(self.device) + + questions.input_ids[:, 0] = \ + self.tokenizer.additional_special_tokens_ids[0] + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + questions.input_ids, + attention_mask=questions.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + if self.vqa_head.inference_method == 'rank': + answer_candidates = self.tokenizer( + self.vqa_head.answer_list, + padding='longest', + return_tensors='pt').to(self.device) + answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id + elif self.vqa_head.inference_method == 'generate': + answer_candidates = None + + head_feats = dict( + multimodal_embeds=multimodal_embeds.last_hidden_state, + question_atts=questions.attention_mask, + answer_candidates=answer_candidates, + bos_token_id=self.tokenizer.bos_token_id, + sep_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + if self.vqa_head.inference_method == 'rank': + answers = self.vqa_head.predict(head_feats) + for answer, data_sample in zip(answers, data_samples): + data_sample.pred_answer = answer + + elif self.vqa_head.inference_method == 'generate': + outputs = self.vqa_head.predict(head_feats) + for output, data_sample in zip(outputs, data_samples): + data_sample.pred_answer = self.tokenizer.decode( + output, skip_special_tokens=True) + + return data_samples diff --git a/mmpretrain/models/multimodal/blip/language_model.py b/mmpretrain/models/multimodal/blip/language_model.py new file mode 100644 index 0000000000000000000000000000000000000000..48605a95f60550e970f893f55c4a43e03efb74df --- /dev/null +++ b/mmpretrain/models/multimodal/blip/language_model.py @@ -0,0 +1,1320 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# flake8: noqa + +import math +from typing import Tuple + +import torch +import torch.nn as nn +from torch import Tensor, device + +try: + from transformers.activations import ACT2FN + from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) + from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + from transformers.models.bert.configuration_bert import BertConfig +except: + ACT2FN = None + BaseModelOutputWithPastAndCrossAttentions = None + BaseModelOutputWithPoolingAndCrossAttentions = None + CausalLMOutputWithCrossAttentions = None + PreTrainedModel = None + apply_chunking_to_forward = None + find_pruneable_heads_and_indices = None + prune_linear_layer = None + BertConfig = None + +from mmpretrain.registry import MODELS + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + if config.add_type_embeddings: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if token_type_ids is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + else: + embeddings = inputs_embeds + + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple + interface for downloading and loading pretrained models.""" + + config_class = BertConfig + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = ( + attention_scores + relative_position_scores_query + + relative_position_scores_key) + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ((context_layer, attention_probs) if output_attentions else + (context_layer, )) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config, twin=False, merge=False): + super().__init__() + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if twin: + self.dense0 = nn.Linear(config.hidden_size, config.hidden_size) + self.dense1 = nn.Linear(config.hidden_size, config.hidden_size) + else: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if merge: + self.act = ACT2FN[config.hidden_act] + self.merge_layer = nn.Linear(config.hidden_size * 2, + config.hidden_size) + self.merge = True + else: + self.merge = False + + def forward(self, hidden_states, input_tensor): + if type(hidden_states) == list: + hidden_states0 = self.dense0(hidden_states[0]) + hidden_states1 = self.dense1(hidden_states[1]) + if self.merge: + hidden_states = self.merge_layer( + torch.cat([hidden_states0, hidden_states1], dim=-1)) + else: + hidden_states = (hidden_states0 + hidden_states1) / 2 + else: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False, layer_num=-1): + super().__init__() + is_nlvr = is_cross_attention and getattr(config, 'nlvr', False) + if is_nlvr: + self.self0 = BertSelfAttention(config, is_nlvr) + self.self1 = BertSelfAttention(config, is_nlvr) + else: + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput( + config, + twin=is_nlvr, + merge=(is_nlvr and layer_num >= 6), + ) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + if type(encoder_hidden_states) == list: + self_outputs0 = self.self0( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[0], + encoder_attention_mask[0], + past_key_value, + output_attentions, + ) + self_outputs1 = self.self1( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[1], + encoder_attention_mask[1], + past_key_value, + output_attentions, + ) + attention_output = self.output( + [self_outputs0[0], self_outputs1[0]], hidden_states) + + outputs = (attention_output, ) + self_outputs0[ + 1:] # add attentions if we output them + else: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + + # compatibility for ALBEF and BLIP + try: + # ALBEF & ALPRO + fusion_layer = self.config.fusion_layer + add_cross_attention = ( + fusion_layer <= layer_num and self.config.add_cross_attention) + + self.fusion_layer = fusion_layer + except AttributeError: + # BLIP + self.fusion_layer = self.config.num_hidden_layers + add_cross_attention = self.config.add_cross_attention + + # if self.config.add_cross_attention: + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, + is_cross_attention=self.config.add_cross_attention, + layer_num=layer_num, + ) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + # TODO line 482 in albef/models/xbert.py + # compatibility for ALBEF and BLIP + if mode in ['multimodal', 'fusion'] and hasattr( + self, 'crossattention'): + assert ( + encoder_hidden_states is not None + ), 'encoder_hidden_states must be given for cross-attention layers' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = (outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = (() if output_attentions + and self.config.add_cross_attention else None) + + next_decoder_cache = () if use_cache else None + + try: + # ALBEF + fusion_layer = self.config.fusion_layer + except AttributeError: + # BLIP + fusion_layer = self.config.num_hidden_layers + + if mode == 'text': + start_layer = 0 + # output_layer = self.config.fusion_layer + output_layer = fusion_layer + + elif mode == 'fusion': + # start_layer = self.config.fusion_layer + start_layer = fusion_layer + output_layer = self.config.num_hidden_layers + + elif mode == 'multimodal': + start_layer = 0 + output_layer = self.config.num_hidden_layers + + # compatibility for ALBEF and BLIP + # for i in range(self.config.num_hidden_layers): + for i in range(start_layer, output_layer): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + # TODO pay attention to this. + if self.gradient_checkpointing and self.training: + + if use_cache: + # TODO: logger here + # logger.warn( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + # ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@MODELS.register_module() +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need `__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. + """ + + def __init__(self, config, add_pooling_layer=True): + if not isinstance(config, BertConfig): + config = BertConfig.from_dict(config) + + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + ) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= + seq_ids[None, :, None]) + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:, None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or encoder_embeds' + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] + if past_key_values is not None else 0) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BaseEncoder(nn.Module): + """Base class for primitive encoders, such as ViT, TimeSformer, etc.""" + + def __init__(self): + super().__init__() + + def forward_features(self, samples, **kwargs): + raise NotImplementedError + + @property + def device(self): + return list(self.parameters())[0].device + + +@MODELS.register_module() +class XBertEncoder(BertModel, BaseEncoder): + + def __init__(self, med_config, from_pretrained=False): + + med_config = BertConfig.from_dict(med_config) + super().__init__(config=med_config, add_pooling_layer=False) + + def forward_automask(self, tokenized_text, visual_embeds, **kwargs): + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + text = tokenized_text + text_output = super().forward( + text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return text_output + + def forward_text(self, tokenized_text, **kwargs): + text = tokenized_text + token_type_ids = kwargs.get('token_type_ids', None) + + text_output = super().forward( + text.input_ids, + attention_mask=text.attention_mask, + token_type_ids=token_type_ids, + return_dict=True, + mode='text', + ) + + return text_output + + +@MODELS.register_module() +class Linear(torch.nn.Linear): + """Wrapper for linear function.""" + + +@MODELS.register_module() +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, + BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained( + 'bert-base-cased') + >>> config = BertConfig.from_pretrained( + "bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained( + 'bert-base-cased', config=config) + >>> inputs = tokenizer( + "Hello, my dog is cute", + return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, + # the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + +@MODELS.register_module() +class XBertLMHeadDecoder(BertLMHeadModel): + """This class decouples the decoder forward logic from the VL model. + + In this way, different VL models can share this decoder as long as they + feed encoder_embeds as required. + """ + + def __init__(self, med_config): + self.med_config = BertConfig.from_dict(med_config) + super(XBertLMHeadDecoder, self).__init__(config=self.med_config) + + def generate_from_encoder(self, + tokenized_prompt, + visual_embeds, + sep_token_id, + pad_token_id, + use_nucleus_sampling=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + **kwargs): + + if not use_nucleus_sampling: + num_beams = num_beams + visual_embeds = visual_embeds.repeat_interleave(num_beams, dim=0) + + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + model_kwargs = { + 'encoder_hidden_states': visual_embeds, + 'encoder_attention_mask': image_atts, + } + + if use_nucleus_sampling: + # nucleus sampling + outputs = self.generate( + input_ids=tokenized_prompt.input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + # beam search + outputs = self.generate( + input_ids=tokenized_prompt.input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + return outputs diff --git a/mmpretrain/models/multimodal/blip2/Qformer.py b/mmpretrain/models/multimodal/blip2/Qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b85f9ee66020fb86282a89840cc2556a5dec06f --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/Qformer.py @@ -0,0 +1,772 @@ +# flake8: noqa +""" + * Copyright (c) 2023, salesforce.com, inc. +""" +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import apply_chunking_to_forward +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +from mmpretrain.registry import MODELS +from ..blip.language_model import (BertAttention, BertIntermediate, + BertOnlyMLMHead, BertOutput, BertPooler, + BertPreTrainedModel) + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if (self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), 'encoder_hidden_states must be given for cross-attention layers' + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], + dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = (() if output_attentions + and self.config.add_cross_attention else None) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + + if use_cache: + logger.warn( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions, query_length) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need `__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= + seq_ids[None, :, None]) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], + prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + if input_ids is None: + assert ( + query_embeds is not None + ), 'You have to specify query_embeds when input_ids is None' + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - + self.config.query_length if past_key_values is not None else 0) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 + tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1]:, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + query_embeds, + past=None, + attention_mask=None, + **model_kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'query_embeds': + query_embeds, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + +@MODELS.register_module() +class Qformer(BertLMHeadModel): + + def __init__(self, model_style: str, vision_model_width: int, + add_cross_attention: bool, cross_attention_freq: int, + num_query_token: int) -> None: + + config = BertConfig.from_pretrained(model_style) + config.add_cross_attention = add_cross_attention + config.encoder_width = vision_model_width + config.cross_attention_freq = cross_attention_freq + config.query_length = num_query_token + super().__init__(config) diff --git a/mmpretrain/models/multimodal/blip2/__init__.py b/mmpretrain/models/multimodal/blip2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5695f236caf74493fc6e851edbf2a4a05146b5f --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .blip2_caption import Blip2Caption +from .blip2_opt_vqa import Blip2VQA +from .blip2_retriever import Blip2Retrieval +from .modeling_opt import OPTForCausalLM +from .Qformer import Qformer + +__all__ = [ + 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'OPTForCausalLM', 'Qformer' +] diff --git a/mmpretrain/models/multimodal/blip2/__pycache__/Qformer.cpython-38.pyc b/mmpretrain/models/multimodal/blip2/__pycache__/Qformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fd22c02c9cafae4c08ace971a20dbc0890619c8 Binary files /dev/null and b/mmpretrain/models/multimodal/blip2/__pycache__/Qformer.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip2/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/blip2/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f777cbb3137d35e4eea377f7ae2409977406a80e Binary files /dev/null and b/mmpretrain/models/multimodal/blip2/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip2/__pycache__/blip2_caption.cpython-38.pyc b/mmpretrain/models/multimodal/blip2/__pycache__/blip2_caption.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..744fee1f4b11d9e6ac20c1bb8b903b8489a68219 Binary files /dev/null and b/mmpretrain/models/multimodal/blip2/__pycache__/blip2_caption.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip2/__pycache__/blip2_opt_vqa.cpython-38.pyc b/mmpretrain/models/multimodal/blip2/__pycache__/blip2_opt_vqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b7aa777b68223ac3e6acc4ece804c78c574710e Binary files /dev/null and b/mmpretrain/models/multimodal/blip2/__pycache__/blip2_opt_vqa.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip2/__pycache__/blip2_retriever.cpython-38.pyc b/mmpretrain/models/multimodal/blip2/__pycache__/blip2_retriever.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbdd6aa25c6e552f2bd19e7c23016853a43df5ee Binary files /dev/null and b/mmpretrain/models/multimodal/blip2/__pycache__/blip2_retriever.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip2/__pycache__/modeling_opt.cpython-38.pyc b/mmpretrain/models/multimodal/blip2/__pycache__/modeling_opt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b475d18e10be92fe90047eacc79be757994598bc Binary files /dev/null and b/mmpretrain/models/multimodal/blip2/__pycache__/modeling_opt.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/blip2/blip2_caption.py b/mmpretrain/models/multimodal/blip2/blip2_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..7b409b07acbb84c7e3f15d49ca7a3636beee6004 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/blip2_caption.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class Blip2Caption(BaseModel): + """BLIP2 Caption. + + Module for BLIP2 Caption task. + + Args: + vision_backbone (dict): The config dict for vision backbone. + text_backbone (dict): The config dict for text backbone. + multimodal_backbone (dict): The config dict for multimodal backbone. + vision_neck (dict): The config dict for vision neck. + tokenizer: (Optional[dict]): The config for tokenizer. + Defaults to None. + prompt (str): Prompt used for training and eval. + Defaults to ''. + max_txt_len (int): Max text length of input text. + num_captions (int): Number of captions to be generated for each image. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + _no_split_modules = ['BEiTViT', 'OPTDecoderLayer', 'BertLayer'] + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + multimodal_backbone: dict, + vision_neck: dict, + tokenizer: Optional[dict] = None, + prompt: str = '', + max_txt_len: int = 20, + num_captions: int = 1, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.eos_token_id = self.tokenizer( + '\n', add_special_tokens=False).input_ids[0] + + self.vision_backbone = MODELS.build(vision_backbone) + self.ln_vision_backbone = nn.LayerNorm(self.vision_backbone.embed_dims) + + self.vision_neck = MODELS.build(vision_neck) + + self.text_backbone = MODELS.build(text_backbone) + + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.multimodal_backbone.cls = None + self.multimodal_backbone.bert.embeddings.word_embeddings = None + self.multimodal_backbone.bert.embeddings.position_embeddings = None + for layer in self.multimodal_backbone.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + self.prompt = prompt + self.max_txt_len = max_txt_len + self.num_captions = num_captions + prompt_tokens = self.tokenizer(prompt, return_tensors='pt') + self.prompt_length = prompt_tokens.attention_mask.sum(1) + + self.query_tokens = nn.Parameter( + torch.zeros(1, self.multimodal_backbone.bert.config.query_length, + self.multimodal_backbone.bert.config.hidden_size)) + self.query_tokens.data.normal_( + mean=0.0, + std=self.multimodal_backbone.bert.config.initializer_range) + + # freeze the text backbone + for _, param in self.text_backbone.named_parameters(): + param.requires_grad = False + + if hasattr(self, 'register_load_state_dict_post_hook'): + self.register_load_state_dict_post_hook(self._ignore_llm_keys_hook) + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List] = None, + mode: str = 'loss', + ) -> List[DataSample]: + """The unified entry for a forward process in both training and test. + The method should accept two modes: "predict" and "loss": + + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (torch.Tensor): pre_processed img tensor (N, C, ...). + data_samples (List[DataSample], optional): + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, + images: torch.Tensor, + data_samples: Optional[list] = None, + **kwargs) -> List[DataSample]: + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: Return list of data samples. + """ + + # extract image features from + image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], + dtype=torch.long, + ).to(images.device) + + # distill image features to query tokens + query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1) + query_outputs = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + inputs_opt = self.vision_neck([query_outputs.last_hidden_state]) + attns_opt = torch.ones( + inputs_opt.size()[:-1], dtype=torch.long).to(images.device) + + prompt = [self.prompt] * image_embeds.size(0) + + opt_tokens = self.tokenizer( + prompt, return_tensors='pt').to(images.device) + input_ids = opt_tokens.input_ids + attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask], + dim=1) + + query_embeds = inputs_opt + + outputs = self.text_backbone.generate( + input_ids=input_ids, + query_embeds=query_embeds, + attention_mask=attention_mask, + do_sample=False, + top_p=0.9, + temperature=1., + num_beams=5, + max_new_tokens=self.max_txt_len, + min_length=1, + eos_token_id=self.eos_token_id, + repetition_penalty=1.0, + length_penalty=1.0, + num_return_sequences=self.num_captions, + ) + + output_text = self.tokenizer.batch_decode( + outputs[:, self.prompt_length:], skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(len(output_text))] + + for data_sample, decode_token in zip(data_samples, output_text): + if data_sample is None: + data_sample = DataSample() + data_sample.pred_caption = decode_token + out_data_samples.append(data_sample) + + return out_data_samples + + @staticmethod + def _ignore_llm_keys_hook(module, incompatible_keys): + """Avoid warning missing keys of the LLM model.""" + import re + llm_pattern = '^text_backbone' + for key in list(incompatible_keys.missing_keys): + if re.match(llm_pattern, key): + incompatible_keys.missing_keys.remove(key) diff --git a/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py b/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..20e439fa826725a80462557faab8ae25a8e5660e --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .blip2_caption import Blip2Caption + + +@MODELS.register_module() +class Blip2VQA(Blip2Caption): + """BLIP2 VQA. + + Module for BLIP2 VQA task. For more details about the initialization + params, please refer to :class:`Blip2Caption`. + """ + + def predict(self, + images: torch.Tensor, + data_samples: Optional[list] = None, + **kwargs) -> List[DataSample]: + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: Return list of data samples. + """ + questions = [d.question for d in data_samples] + + # extract image features from + image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], + dtype=torch.long, + ).to(images.device) + + # distill image features to query tokens + query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1) + query_outputs = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + inputs_opt = self.vision_neck([query_outputs.last_hidden_state]) + attns_opt = torch.ones( + inputs_opt.size()[:-1], dtype=torch.long).to(images.device) + + prompt = [self.prompt.format(q) for q in questions] + + # use left padding + self.tokenizer.padding_side = 'left' + + opt_tokens = self.tokenizer( + prompt, return_tensors='pt', padding='longest').to(images.device) + input_ids = opt_tokens.input_ids + attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask], + dim=1) + + inputs_embeds = self.text_backbone.model.decoder.embed_tokens( + input_ids) + inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) + + outputs = self.text_backbone.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=False, + num_beams=5, + max_new_tokens=self.max_txt_len, + min_length=1, + eos_token_id=self.eos_token_id, + length_penalty=-1.0, + ) + + output_text = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + + out_data_samples = [] + for data_sample, decode_token in zip(data_samples, output_text): + data_sample.pred_answer = decode_token + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/blip2/blip2_retriever.py b/mmpretrain/models/multimodal/blip2/blip2_retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..e626404a4cde5798151a0fa9589716470ed928a9 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/blip2_retriever.py @@ -0,0 +1,505 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import mmengine.dist as dist +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.utils import track_iter_progress + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from ..blip.blip_retrieval import BlipRetrieval, all_gather_concat + + +@MODELS.register_module() +class Blip2Retrieval(BlipRetrieval): + """BLIP2 Retriever. + + Args: + vision_backbone (dict): Backbone for extracting image features. + text_backbone (dict): Backbone for extracting text features. + multimodal_backbone (Optional[dict]): Backbone for extracting + multi-modal features. + vision_neck (Optional[dict]): The neck module to process image features + from vision backbone. Defaults to None. + text_neck (Optional[dict]): The neck module to process text features + from text backbone. Defaults to None. + head (Optional[Union[List[dict], dict]]): The head module to calculate + loss from processed single modality features. + See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + multimodal_head (Optional[Union[List[dict], dict]]): The multi-modal + head module to calculate loss from processed multimodal features. + See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + tokenizer (Optional[dict]): The config for tokenizer. Defaults to None. + temperature (float): Temperature parameter that controls the + concentration level of the distribution. Defaults to 0.07. + fast_match (bool): If False, select topk similarity as candidates and + compute the matching score. If True, return the similarity as the + matching score directly. Defaults to False. + topk (int): Select topk similarity as candidates for compute matching + scores. Notice that this is not the topk in evaluation. + Defaults to 256. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + text_backbone: Optional[dict] = None, + multimodal_backbone: Optional[dict] = None, + vision_neck: Optional[dict] = None, + text_neck: Optional[dict] = None, + head: Optional[Union[List[dict], dict]] = None, + multimodal_head: Optional[Union[List[dict], dict]] = None, + tokenizer: Optional[dict] = None, + temperature: float = 0.07, + fast_match: bool = False, + topk: int = 256, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + # Skip BlipRetrieval init + super(BlipRetrieval, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.vision_backbone = MODELS.build(vision_backbone) + self.ln_vision_backbone = nn.LayerNorm(self.vision_backbone.embed_dims) + self.tokenizer = TOKENIZER.build(tokenizer) + + if text_backbone is not None: + self.text_backbone = MODELS.build(text_backbone) + + if multimodal_backbone is not None: + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.multimodal_backbone.resize_token_embeddings( + len(self.tokenizer)) + self.query_tokens = nn.Parameter( + torch.zeros(1, self.multimodal_backbone.bert.config.query_length, + self.multimodal_backbone.bert.config.hidden_size)) + self.query_tokens.data.normal_( + mean=0.0, + std=self.multimodal_backbone.bert.config.initializer_range) + + if vision_neck is not None: + self.vision_neck = MODELS.build(vision_neck) + + if text_neck is not None: + self.text_neck = MODELS.build(text_neck) + + if head is not None: + self.head = MODELS.build(head) + + if multimodal_head is not None: + self.multimodal_head = MODELS.build(multimodal_head) + + self.temp = nn.Parameter(temperature * torch.ones([])) + + # Notice that this topk is used for select k candidate to compute + # image-text score, but not the final metric topk in evaluation. + self.fast_match = fast_match + self.topk = topk + + def _extract_feat(self, inputs: Union[torch.Tensor, dict], + modality: str) -> Tuple[torch.Tensor]: + """Extract features from the single modality. + Args: + inputs (Union[torch.Tensor, dict]): A batch of inputs. + For image, a tensor of shape (N, C, ...) in general. + For text, a dict of tokenized text inputs. + modality (str): Modality feature to be extracted. Only two + options are supported. + + - ``images``: Only extract image features, mostly used for + inference. + - ``texts``: Only extract text features, mostly used for + inference. + Returns: + Tuple[torch.Tensor]: The output features. + """ + if modality == 'images': + # extract image features + # TODO: + # Add layernorm inside backbone and handle the concat outside + image_embeds = self.ln_vision_backbone( + self.vision_backbone(inputs)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, + -1) + query_output = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + use_cache=True, + return_dict=True, + ) + image_feat = F.normalize( + self.vision_neck([query_output.last_hidden_state]), dim=-1) + return { + 'image_embeds': image_embeds, + 'image_feat': image_feat, + 'query_output': query_output + } + elif modality == 'texts': + # extract text features + text_output = self.multimodal_backbone.bert( + inputs.input_ids, + attention_mask=inputs.attention_mask, + return_dict=True, + ) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize( + self.text_neck([text_embeds[:, 0, :]]), dim=-1) + return {'text_embeds': text_embeds, 'text_feat': text_feat} + else: + raise RuntimeError(f'Invalid modality "{modality}".') + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Dict[str, torch.tensor]: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (dict): A batch of inputs. The input tensor with of + at least one modality. For image, the value is a tensor + of shape (N, C, ...) in general. + For text, the value is a dict of tokenized text inputs. + data_samples (Optional[List[DataSample]]): + The annotation data of every samples. Defaults to None. + + Returns: + Dict[str, torch.tensor]: a dictionary of loss components of + both head and multimodal head. + """ + output = self.extract_feat(images, data_samples) + + text_ids = output['text_ids'] + text_attn_mask = output['text_attn_mask'] + image_embeds = output['image_embeds'] + image_feat = output['image_feat'] + text_feat = output['text_feat'] + query_output = output['query_output'] + + # ITC Loss + # B*world_size, num_query, D + image_feat_all = torch.cat(dist.all_gather(image_feat)) + # B*world_size, D + text_feat_all = torch.cat(dist.all_gather(text_feat)) + + # B, B*world_size, num_query + sim_q2t = torch.matmul( + image_feat.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze() + + # image to text similarity + sim_i2t, _ = sim_q2t.max(-1) + sim_i2t = sim_i2t / self.temp + + # B, B*world_size, num_query + sim_t2q = torch.matmul( + text_feat.unsqueeze(1).unsqueeze(1), + image_feat_all.permute(0, 2, 1)).squeeze() + + # text-image similarity + sim_t2i, _ = sim_t2q.max(-1) + sim_t2i = sim_t2i / self.temp + + rank = dist.get_rank() + bs = images.size(0) + targets = torch.linspace( + rank * bs, rank * bs + bs - 1, bs, dtype=int).to(self.device) + + itc_loss = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2 + + # prepare for itm + text_input_ids_world = torch.cat(dist.all_gather(text_ids)) + text_attention_mask_world = torch.cat(dist.all_gather(text_attn_mask)) + image_embeds_world = torch.cat(dist.all_gather(image_embeds)) + with torch.no_grad(): + weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4 + weights_t2i[:, rank * bs:rank * bs + bs].fill_diagonal_(0) + weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4 + weights_i2t[:, rank * bs:rank * bs + bs].fill_diagonal_(0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(text_input_ids_world[neg_idx]) + text_atts_neg.append(text_attention_mask_world[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat([text_ids, text_ids, text_ids_neg], + dim=0) # pos, pos, neg + text_atts_all = torch.cat( + [text_attn_mask, text_attn_mask, text_atts_neg], + dim=0, + ) + + query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, + -1) + query_atts_itm = torch.ones( + query_tokens_itm.size()[:-1], dtype=torch.long).to(self.device) + attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) + + image_embeds_all = torch.cat( + [image_embeds, image_embeds_neg, image_embeds], + dim=0) # pos, neg, pos + image_atts_all = torch.ones( + image_embeds_all.size()[:-1], dtype=torch.long).to(self.device) + + output_itm = self.multimodal_backbone.bert( + text_ids_all, + query_embeds=query_tokens_itm, + attention_mask=attention_mask_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = output_itm.last_hidden_state[:, :query_tokens_itm. + size(1), :] + + # create false data samples + data_samples.extend( + [DataSample(is_matched=False) for _ in range(2 * bs)]) + loss_multimodal = self.multimodal_head.loss((vl_embeddings, ), + data_samples) + + # LM loss + decoder_input_ids = text_ids.clone() + decoder_input_ids[:, 0] = self.tokenizer.bos_token_id + labels = decoder_input_ids.masked_fill( + decoder_input_ids == self.tokenizer.pad_token_id, -100) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_atts = torch.ones( + query_tokens.size()[:-1], dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text_attn_mask], dim=1) + lm_output = self.multimodal_backbone( + decoder_input_ids, + attention_mask=attention_mask, + past_key_values=query_output.past_key_values, + return_dict=True, + labels=labels, + ) + + return dict( + itc_loss=itc_loss, **loss_multimodal, lm_loss=lm_output.loss) + + def predict_all(self, + feats: Dict[str, torch.Tensor], + data_samples: List[DataSample], + num_images: int = None, + num_texts: int = None, + cal_i2t: bool = True, + cal_t2i: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute similarity matrix between images and texts across all ranks. + + Args: + feats (Dict[str, torch.Tensor]): Features from the current rank. + data_samples (List[DataSample]): Data samples from the current + rank. + num_images (int, optional): Number of images to use. + Defaults to None. + num_texts (int, optional): Number of texts to use. + Defaults to None. + cal_i2t (bool, optional): Whether to compute image-to-text + similarity. Defaults to True. + cal_t2i (bool, optional): Whether to compute text-to-image + similarity. Defaults to True. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Image-to-text and text-to-image + similarity matrices. + """ + text_ids = feats['text_ids'] + text_attn_mask = feats['text_attn_mask'] + image_embeds = feats.get('image_embeds', None) + image_feat = feats['image_feat'] + text_feat = feats['text_feat'] + + num_images = num_images or image_feat.size(0) + num_texts = num_texts or text_feat.size(0) + + if not self.fast_match: + image_embeds_all = all_gather_concat(image_embeds)[:num_images] + else: + image_embeds_all = None + image_feat_all = all_gather_concat(image_feat)[:num_images] + text_feat_all = all_gather_concat(text_feat)[:num_texts] + text_ids_all = all_gather_concat(text_ids)[:num_texts] + text_attn_mask_all = all_gather_concat(text_attn_mask)[:num_texts] + + results = [] + if cal_i2t: + result_i2t = self.compute_score_matrix_i2t( + image_feat, + image_embeds, + text_feat_all, + text_ids_all, + text_attn_mask_all, + ) + results.append( + self._get_predictions(result_i2t, data_samples, mode='i2t')) + if cal_t2i: + result_t2i = self.compute_score_matrix_t2i( + image_feat_all, + image_embeds_all, + text_feat, + text_ids, + text_attn_mask, + ) + results.append( + self._get_predictions(result_t2i, data_samples, mode='t2i')) + return tuple(results) + + def compute_score_matrix_i2t(self, img_feats: torch.Tensor, + img_embeds: List[torch.Tensor], + text_feats: torch.Tensor, + text_ids: torch.Tensor, + text_atts: torch.Tensor) -> torch.Tensor: + """Compare the score matrix for image-to-text retrieval. Every image + should compare to all the text features. + + Args: + img_feats (torch.Tensor): The input tensor with shape (M, C). + M stands for numbers of samples on a single GPU. + img_embeds (List[torch.Tensor]): Image features from each layer of + the vision backbone. + text_feats (torch.Tensor): The input tensor with shape (N, C). + N stands for numbers of all samples on all GPUs. + text_ids (torch.Tensor): The input tensor with shape (N, C). + text_atts (torch.Tensor): The input tensor with shape (N, C). + + Returns: + torch.Tensor: Score matrix of image-to-text retrieval. + """ + + # compute i2t sim matrix + # TODO: check correctness + sim_matrix_i2t, _ = (img_feats @ text_feats.t()).max(1) + if self.fast_match: + return sim_matrix_i2t + + score_matrix_i2t = torch.full((img_feats.size(0), text_feats.size(0)), + -100.0).to(self.device) + + for i in track_iter_progress(range(img_feats.size(0))): + sims = sim_matrix_i2t[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + # get repeated image embeddings + encoder_output = img_embeds[i].repeat(self.topk, 1, 1) + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + # query embeds and attention masks + query_tokens = self.query_tokens.expand(encoder_output.shape[0], + -1, -1) + query_atts = torch.ones( + query_tokens.size()[:-1], dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text_atts[topk_idx]], + dim=1) + output = self.multimodal_backbone.bert( + text_ids[topk_idx], + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, :query_tokens.size(1), :], + ))[:, :, 1].mean(dim=1) + score_matrix_i2t[i, topk_idx] = score + topk_sim + + return score_matrix_i2t + + def compute_score_matrix_t2i(self, img_feats: torch.Tensor, + img_embeds: List[torch.Tensor], + text_feats: torch.Tensor, + text_ids: torch.Tensor, + text_atts: torch.Tensor) -> torch.Tensor: + """Compare the score matrix for text-to-image retrieval. + + Every text should compare to all the image features. + + Args: + img_feats (torch.Tensor): The input tensor with shape (N, C). + N stands for numbers of all samples on all GPUs. + img_embeds (List[torch.Tensor]): Image features from each layer of + the vision backbone. + text_feats (torch.Tensor): The input tensor with shape (M, C). + M stands for numbers of samples on a single GPU. + text_ids (torch.Tensor): The input tensor with shape (M, C). + text_atts (torch.Tensor): The input tensor with shape (M, C). + + Returns: + torch.Tensor: Score matrix of text-to-image retrieval. + """ + + # compute t2i sim matrix + # TODO: check correctness + sim_matrix_i2t, _ = (img_feats @ text_feats.t()).max(1) + sim_matrix_t2i = sim_matrix_i2t.t() + if self.fast_match: + return sim_matrix_i2t + + score_matrix_t2i = torch.full((text_feats.size(0), img_feats.size(0)), + -100.0).to(self.device) + + for i in track_iter_progress(range(text_feats.size(0))): + sims = sim_matrix_t2i[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + # get topk image embeddings + encoder_output = img_embeds[topk_idx] + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + # get query embeds and attention masks + query_tokens = self.query_tokens.expand(encoder_output.shape[0], + -1, -1) + query_atts = torch.ones( + query_tokens.size()[:-1], dtype=torch.long).to(self.device) + attention_mask = torch.cat( + [query_atts, text_atts[i].repeat(self.topk, 1)], dim=1) + output = self.multimodal_backbone.bert( + text_ids[i].repeat(self.topk, 1), + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, :query_tokens.size(1), :], + ))[:, :, 1].mean(dim=1) + score_matrix_t2i[i, topk_idx] = score + topk_sim + + return score_matrix_t2i diff --git a/mmpretrain/models/multimodal/blip2/modeling_opt.py b/mmpretrain/models/multimodal/blip2/modeling_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..7cde0d76a2079a610bd71ed034c0c88940244e76 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/modeling_opt.py @@ -0,0 +1,1083 @@ +# flake8: noqa +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OPT model.""" +import random +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) + +from mmpretrain.models.utils import register_hf_model + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'facebook/opt-350m' +_CONFIG_FOR_DOC = 'OPTConfig' +_TOKENIZER_FOR_DOC = 'GPT2Tokenizer' + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'facebook/opt-125m', + 'facebook/opt-350m', + 'facebook/opt-1.3b', + 'facebook/opt-2.7b', + 'facebook/opt-6.7b', + 'facebook/opt-13b', + 'facebook/opt-30b', + # See all OPT models at https://huggingface.co/models?filter=opt +] + + +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + past_key_values_length: int = 0): + """Make causal mask used for bi-directional self-attention.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, + src_seq_len]`.""" + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), + torch.finfo(dtype).min) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """This module learns positional embeddings up to a fixed maximum size.""" + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = ( + torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * + attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}' + f' and `num_heads`: {num_heads}).') + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return (tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous()) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel.""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f'Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is' + f' {attn_weights.size()}') + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}' + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask) + attn_weights = torch.max( + attn_weights, + torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads, ): + raise ValueError( + f'Head mask for a single layer should be of size {(self.num_heads,)}, but is' + f' {layer_head_mask.size()}') + attn_weights = layer_head_mask.view( + 1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, + tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, + self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is' + f' {attn_output.size()}') + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, + self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OPTDecoderLayer(nn.Module): + + def __init__(self, config: OPTConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, ) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + 'The bare OPT Model outputting raw hidden-states without any specific head on top.', + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + + config_class = OPTConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['OPTDecoderLayer'] + _keys_to_ignore_on_load_unexpected = [r'decoder\.version'] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (OPTDecoder)): + module.gradient_checkpointing = value + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """Transformer decoder consisting of *config.num_hidden_layers* layers. + Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.word_embed_proj_dim, + self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear( + config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear( + config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList( + [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ).to(inputs_embeds.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else + expanded_attn_mask + combined_attention_mask) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time' + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + 'You have to specify either decoder_input_ids or decoder_inputs_embeds' + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if query_embeds is not None: + inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1) + input_shape = inputs_embeds.size()[:-1] + else: + input_shape = (batch_size, seq_length) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + inputs_embeds.shape[:2], + dtype=torch.bool, + device=inputs_embeds.device) + pos_embeds = self.embed_positions(attention_mask, + past_key_values_length) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ['head_mask']): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f'The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for' + f' {head_mask.size()[0]}.') + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += ( + layer_outputs[2 if output_attentions else 1], ) + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v for v in + [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + 'The bare OPT Model outputting raw hidden-states without any specific head on top.', + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.decoder = OPTDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +@register_hf_model() +class OPTForCausalLM(OPTPreTrainedModel): + _keys_to_ignore_on_load_missing = [r'lm_head.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear( + config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = 'mean', + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import GPT2Tokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + logits = logits[:, -labels.size(1):, :] + + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + loss = loss_fct( + shift_logits.view(-1, self.config.vocab_size), + shift_labels.view(-1)) + if reduction == 'none': + loss = loss.view(shift_logits.size(0), -1).sum(1) + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + query_embeds=None, + past_key_values=None, + attention_mask=None, + use_cache=None, + **kwargs, + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + if input_ids is not None: + attention_mask = input_ids.new_ones(input_ids.shape) + if past_key_values: + input_ids = input_ids[:, -1:] + query_embeds = None + # first step, decoder_cached_states are empty + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ + 'query_embeds': query_embeds, + 'attention_mask': attention_mask, + 'past_key_values': past_key_values, + 'use_cache': use_cache, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/chinese_clip/__init__.py b/mmpretrain/models/multimodal/chinese_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..460e9e6a6be748113df029ad76bc0934ab7704d3 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bert import BertModelCN +from .chinese_clip import ChineseCLIP, ModifiedResNet + +__all__ = ['ChineseCLIP', 'ModifiedResNet', 'BertModelCN'] diff --git a/mmpretrain/models/multimodal/chinese_clip/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/chinese_clip/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd7eea4704fc4fc64303d4b58813465942c80d06 Binary files /dev/null and b/mmpretrain/models/multimodal/chinese_clip/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/chinese_clip/__pycache__/bert.cpython-38.pyc b/mmpretrain/models/multimodal/chinese_clip/__pycache__/bert.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa09cd7bc7aa23ea1cb064d9111ab0b32dba2d8 Binary files /dev/null and b/mmpretrain/models/multimodal/chinese_clip/__pycache__/bert.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/chinese_clip/__pycache__/chinese_clip.cpython-38.pyc b/mmpretrain/models/multimodal/chinese_clip/__pycache__/chinese_clip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19051ace56bf11a22801653a2f9166d911682325 Binary files /dev/null and b/mmpretrain/models/multimodal/chinese_clip/__pycache__/chinese_clip.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/chinese_clip/__pycache__/utils.cpython-38.pyc b/mmpretrain/models/multimodal/chinese_clip/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1038dd37f53713c5e0e3f62ac7e74e2cce46beab Binary files /dev/null and b/mmpretrain/models/multimodal/chinese_clip/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/chinese_clip/bert.py b/mmpretrain/models/multimodal/chinese_clip/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8dc7322a9aaddb0f5e02f8b70597ba08a8b925 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/bert.py @@ -0,0 +1,263 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + +# flake8: noqa +import math + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +try: + from transformers.models.bert.configuration_bert import BertConfig +except: + BertConfig = None + +from mmpretrain.registry import MODELS +from ..blip.language_model import BertAttention, BertIntermediate, BertOutput + + +def gelu(x): + """Original Implementation of the gelu activation function in Google Bert + repo when initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives + slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ # noqa + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def gelu_new(x): + """Implementation of the gelu activation function currently in Google Bert + repo (identical to OpenAI GPT) https://arxiv.org/abs/1606.08415.""" + return 0.5 * x * (1 + torch.tanh( + math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = { + 'gelu': gelu, + 'relu': torch.nn.functional.relu, + 'swish': swish, + 'gelu_new': gelu_new +} + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type + embeddings.""" + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model + # variable name and be able to load any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, position_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings \ + + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertLayer(nn.Module): + + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + attention_outputs = self.attention(hidden_states, attention_mask, + head_mask) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output, ) + attention_outputs[ + 1:] # add attentions if we output them + if len(outputs) == 1: + return outputs[0] + return outputs + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super(BertEncoder, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.grad_checkpointing = False + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + layer_outputs = checkpoint(layer_module, hidden_states, + attention_mask, head_mask[i]) + else: + layer_outputs = layer_module(hidden_states, attention_mask, + head_mask[i]) + if not isinstance(layer_outputs, tuple): + layer_outputs = (layer_outputs, ) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + outputs = (hidden_states, ) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if self.output_attentions: + outputs = outputs + (all_attentions, ) + # last-layer hidden state, (all hidden states), (all attentions) + return outputs + + +class BertPreTrainedModel(nn.Module): + base_model_prefix = 'bert' + + def __init__(self, config): + super(BertPreTrainedModel, self).__init__() + self.config = config + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version + # which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +@MODELS.register_module() +class BertModelCN(BertPreTrainedModel): + """The BERT model implementation for Chinese CLIP.""" + + def __init__(self, config): + config = BertConfig.from_dict(config) + super(BertModelCN, self).__init__(config) + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.apply(self._init_weights) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + if enable: + assert not self.config.output_attentions, \ + 'Grad checkpointing is currently conflict with ' \ + 'output_attentions for BertEncoder, ' \ + 'please set it to False in BertConfig' + + self.encoder.grad_checkpointing = enable + + def forward(self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze( + -1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, + -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1) # We can specify head_mask for each layer + head_mask = head_mask.to(dtype=next(self.parameters( + )).dtype) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids) + encoder_outputs = self.encoder( + embedding_output, extended_attention_mask, head_mask=head_mask) + sequence_output = encoder_outputs[0] + # pooled_output = self.pooler(sequence_output) + pooled_output = None + + # add hidden_states and attentions if they are here + outputs = ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + # sequence_output, pooled_output, (hidden_states), (attentions) + return outputs diff --git a/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py b/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..40af5643602685be4d0e37331609bdecae184de9 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py @@ -0,0 +1,446 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model import BaseModel, BaseModule +from torch import nn + +from mmpretrain.datasets.categories import CIFAR100_CATEGORIES_CN +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .utils import OPENAI_PROMPT + +PROTOTYPE_MAP = {'cifar100': CIFAR100_CATEGORIES_CN} +PROMPT_MAP = {'openai': OPENAI_PROMPT} + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +@MODELS.register_module() +class ModifiedResNet(BaseModule): + """A modified ResNet contains the following changes: + + - Apply deep stem with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is + prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ # noqa + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth: int = 50, + base_channels: int = 64, + input_size: int = 224, + num_attn_heads: int = 32, + output_dim: int = 1024, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.input_size = input_size + self.block, stage_blocks = self.arch_settings[depth] + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, + base_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(base_channels // 2) + self.conv2 = nn.Conv2d( + base_channels // 2, + base_channels // 2, + kernel_size=3, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(base_channels // 2) + self.conv3 = nn.Conv2d( + base_channels // 2, + base_channels, + kernel_size=3, + padding=1, + bias=False) + self.bn3 = nn.BatchNorm2d(base_channels) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + # this is a *mutable* variable used during construction + self._inplanes = base_channels + self.layer1 = self._make_layer(base_channels, stage_blocks[0]) + self.layer2 = self._make_layer( + base_channels * 2, stage_blocks[1], stride=2) + self.layer3 = self._make_layer( + base_channels * 4, stage_blocks[2], stride=2) + self.layer4 = self._make_layer( + base_channels * 8, stage_blocks[3], stride=2) + + embed_dim = base_channels * 32 + self.attnpool = AttentionPool2d(input_size // 32, embed_dim, + num_attn_heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +@MODELS.register_module() +class ChineseCLIP(BaseModel): + """The implementation of `ChineseCLIP `_. + + Args: + vision_backbone (dict): Config dict for vision backbone. + text_backbone (dict): Config dict for text backbone. + tokenizer (dict): Config dict for text tokenizer. + proj_dim (int): Projection dimension for similarity computation. + text_prototype (str): Text prototype, which can be a key in + `PROTOTYPE_MAP` or list of text. + text_prompt (str): The prompt for text prototype. Defaults to 'openai'. + context_length (int): The context length to use. Defaults to 52. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + tokenizer: dict, + proj_dim: int, + text_prototype: Union[str, List[str]], + text_prompt: str = 'openai', + context_length: int = 52, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.vision_backbone = MODELS.build(vision_backbone) + self.text_backbone = MODELS.build(text_backbone) + + if not isinstance(self.vision_backbone, ModifiedResNet): + self.vision_projection = nn.Parameter( + torch.empty(self.vision_backbone.embed_dims, proj_dim)) + text_hidden_size = text_backbone['config']['hidden_size'] + self.text_projection = nn.Parameter( + torch.empty(text_hidden_size, proj_dim)) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.context_length = context_length + + # for zero-shot classification + if isinstance(text_prototype, + str) and text_prototype in PROTOTYPE_MAP.keys(): + self.prototype = PROTOTYPE_MAP[text_prototype] + else: + self.prototype = text_prototype + self.text_prototype_embeds = None + + self.prompt = PROMPT_MAP[text_prompt] + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: + """The function to extract image latent features.""" + if isinstance(self.vision_backbone, ModifiedResNet): + return self.vision_backbone(images) + return self.vision_backbone(images)[-1] @ self.vision_projection + + def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: + """The function to extract text latent features.""" + pad_index = self.tokenizer.vocab['[PAD]'] + attn_mask = texts.ne(pad_index) + # [batch_size, seq_length, hidden_size] + x = self.text_backbone(texts, attention_mask=attn_mask)[0] + return x[:, 0, :] @ self.text_projection + + def extract_feat( + self, images: torch.Tensor, + texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """The function to extract image and text latent features, the input + image or text can not both be None.""" + + assert images is not None or texts is not None, \ + 'text and image cannot both be None!' + if images is None: + return self.extract_text_feat(texts) + elif texts is None: + return self.extract_image_feat(images) + + image_features = self.extract_image_feat(images) + text_features = self.extract_text_feat(texts) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + return image_features, text_features + + def compute_similarity(self, images, texts): + """Extract images and texts features and compute cosine similarity.""" + image_features, text_features = self.extract_feat( + images=images, texts=texts) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape (N, N) + return logits_per_image, logits_per_text + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + """Predict the classes of the input images. + + The prediction is for zero-shot classification and the text prototypes + will be prepared in thisfunction. + + Args: + images (torch.Tensor): The input images. + data_samples (DataSample): The data samples with information from + dataset. + + Returns: + DataSample: The results of prediction. + """ + + if self.text_prototype_embeds is None: + self.prepare_text_prototype(device=images.device) + + image_features = self.extract_image_feat(images=images) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = image_features @ self.text_prototype_embeds.to( + image_features.device) * self.logit_scale.exp() + + pred_scores = F.softmax(logits_per_image, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples + + def prepare_text_prototype(self, device) -> None: + """The function to prepare text prototypes with prompt.""" + class_embeddings = [] + for classname in track_on_main_process(self.prototype, + 'Prepare text prototype...'): + # format with class + texts = [prompt(classname) for prompt in self.prompt] + tokenized_texts = self.tokenize(texts) + class_features = self.extract_text_feat(tokenized_texts.to(device)) + class_features /= class_features.norm(dim=-1, keepdim=True) + class_feature = class_features.mean(dim=0) + class_feature /= class_feature.norm() + class_embeddings.append(class_feature) + self.text_prototype_embeds = torch.stack( + class_embeddings, dim=1).to(device) + + def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Args: + texts (Union[str, List[str]]): An input string or a list of input + strings to tokenize + context_length (int): The context length to use. Defaults to 52. + + Returns: + torch.Tensor: Resulting tokens. + """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + # adapt the text to Chinese BERT vocab + text = text.lower().replace('“', "\"").replace('”', "\"") + + # add special tokens + all_tokens.append( + [self.tokenizer.vocab['[CLS]']] + + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:self.context_length - 2] + + [self.tokenizer.vocab['[SEP]']]) + + result = torch.zeros( + len(all_tokens), self.context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= self.context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/mmpretrain/models/multimodal/chinese_clip/utils.py b/mmpretrain/models/multimodal/chinese_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6964722bd3dbb05a6a59a1dc2c57c0a6e8692c31 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/utils.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +OPENAI_PROMPT = [ + lambda c: f'{c}的照片', + lambda c: f'质量差的{c}的照片', + lambda c: f'许多{c}的照片', + lambda c: f'{c}的雕塑', + lambda c: f'难以看到{c}的照片', + lambda c: f'{c}的低分辨率照片', + lambda c: f'{c}的渲染', + lambda c: f'涂鸦{c}', + lambda c: f'{c}的糟糕照片', + lambda c: f'{c}的裁剪照片', + lambda c: f'{c}的纹身', + lambda c: f'{c}的刺绣照片', + lambda c: f'很难看到{c}的照片', + lambda c: f'{c}的明亮照片', + lambda c: f'一张干净的{c}的照片', + lambda c: f'一张包含{c}的照片', + lambda c: f'{c}的深色照片', + lambda c: f'{c}的手绘画', + lambda c: f'我的{c}的照片', + lambda c: f'不自然的{c}的照片', + lambda c: f'一张酷的{c}的照片', + lambda c: f'{c}的特写照片', + lambda c: f'{c}的黑白照片', + lambda c: f'一幅{c}的画', + lambda c: f'一幅{c}的绘画', + lambda c: f'一张{c}的像素照片', + lambda c: f'{c}的雕像', + lambda c: f'一张{c}的明亮照片', + lambda c: f'{c}的裁剪照片', + lambda c: f'人造的{c}的照片', + lambda c: f'一张关于{c}的照片', + lambda c: f'损坏的{c}的jpeg照片', + lambda c: f'{c}的模糊照片', + lambda c: f'{c}的相片', + lambda c: f'一张{c}的好照片', + lambda c: f'{c}的渲染照', + lambda c: f'视频游戏中的{c}', + lambda c: f'一张{c}的照片', + lambda c: f'{c}的涂鸦', + lambda c: f'{c}的近距离照片', + lambda c: f'{c}的折纸', + lambda c: f'{c}在视频游戏中', + lambda c: f'{c}的草图', + lambda c: f'{c}的涂鸦照', + lambda c: f'{c}的折纸形状', + lambda c: f'低分辨率的{c}的照片', + lambda c: f'玩具{c}', + lambda c: f'{c}的副本', + lambda c: f'{c}的干净的照片', + lambda c: f'一张大{c}的照片', + lambda c: f'{c}的重现', + lambda c: f'一张漂亮的{c}的照片', + lambda c: f'一张奇怪的{c}的照片', + lambda c: f'模糊的{c}的照片', + lambda c: f'卡通{c}', + lambda c: f'{c}的艺术作品', + lambda c: f'{c}的素描', + lambda c: f'刺绣{c}', + lambda c: f'{c}的像素照', + lambda c: f'{c}的拍照', + lambda c: f'{c}的损坏的照片', + lambda c: f'高质量的{c}的照片', + lambda c: f'毛绒玩具{c}', + lambda c: f'漂亮的{c}的照片', + lambda c: f'小{c}的照片', + lambda c: f'照片是奇怪的{c}', + lambda c: f'漫画{c}', + lambda c: f'{c}的艺术照', + lambda c: f'{c}的图形', + lambda c: f'大{c}的照片', + lambda c: f'黑白的{c}的照片', + lambda c: f'{c}毛绒玩具', + lambda c: f'一张{c}的深色照片', + lambda c: f'{c}的摄影图', + lambda c: f'{c}的涂鸦照', + lambda c: f'玩具形状的{c}', + lambda c: f'拍了{c}的照片', + lambda c: f'酷酷的{c}的照片', + lambda c: f'照片里的小{c}', + lambda c: f'{c}的刺青', + lambda c: f'{c}的可爱的照片', + lambda c: f'一张{c}可爱的照片', + lambda c: f'{c}可爱图片', + lambda c: f'{c}酷炫图片', + lambda c: f'一张{c}的酷炫的照片', + lambda c: f'一张{c}的酷炫图片', + lambda c: f'这是{c}', + lambda c: f'{c}的好看照片', + lambda c: f'一张{c}的好看的图片', + lambda c: f'{c}的好看图片', + lambda c: f'{c}的照片。', + lambda c: f'质量差的{c}的照片。', + lambda c: f'许多{c}的照片。', + lambda c: f'{c}的雕塑。', + lambda c: f'难以看到{c}的照片。', + lambda c: f'{c}的低分辨率照片。', + lambda c: f'{c}的渲染。', + lambda c: f'涂鸦{c}。', + lambda c: f'{c}的糟糕照片。', + lambda c: f'{c}的裁剪照片。', + lambda c: f'{c}的纹身。', + lambda c: f'{c}的刺绣照片。', + lambda c: f'很难看到{c}的照片。', + lambda c: f'{c}的明亮照片。', + lambda c: f'一张干净的{c}的照片。', + lambda c: f'一张包含{c}的照片。', + lambda c: f'{c}的深色照片。', + lambda c: f'{c}的手绘画。', + lambda c: f'我的{c}的照片。', + lambda c: f'不自然的{c}的照片。', + lambda c: f'一张酷的{c}的照片。', + lambda c: f'{c}的特写照片。', + lambda c: f'{c}的黑白照片。', + lambda c: f'一幅{c}的画。', + lambda c: f'一幅{c}的绘画。', + lambda c: f'一张{c}的像素照片。', + lambda c: f'{c}的雕像。', + lambda c: f'一张{c}的明亮照片。', + lambda c: f'{c}的裁剪照片。', + lambda c: f'人造的{c}的照片。', + lambda c: f'一张关于{c}的照片。', + lambda c: f'损坏的{c}的jpeg照片。', + lambda c: f'{c}的模糊照片。', + lambda c: f'{c}的相片。', + lambda c: f'一张{c}的好照片。', + lambda c: f'{c}的渲染照。', + lambda c: f'视频游戏中的{c}。', + lambda c: f'一张{c}的照片。', + lambda c: f'{c}的涂鸦。', + lambda c: f'{c}的近距离照片。', + lambda c: f'{c}的折纸。', + lambda c: f'{c}在视频游戏中。', + lambda c: f'{c}的草图。', + lambda c: f'{c}的涂鸦照。', + lambda c: f'{c}的折纸形状。', + lambda c: f'低分辨率的{c}的照片。', + lambda c: f'玩具{c}。', + lambda c: f'{c}的副本。', + lambda c: f'{c}的干净的照片。', + lambda c: f'一张大{c}的照片。', + lambda c: f'{c}的重现。', + lambda c: f'一张漂亮的{c}的照片。', + lambda c: f'一张奇怪的{c}的照片。', + lambda c: f'模糊的{c}的照片。', + lambda c: f'卡通{c}。', + lambda c: f'{c}的艺术作品。', + lambda c: f'{c}的素描。', + lambda c: f'刺绣{c}。', + lambda c: f'{c}的像素照。', + lambda c: f'{c}的拍照。', + lambda c: f'{c}的损坏的照片。', + lambda c: f'高质量的{c}的照片。', + lambda c: f'毛绒玩具{c}。', + lambda c: f'漂亮的{c}的照片。', + lambda c: f'小{c}的照片。', + lambda c: f'照片是奇怪的{c}。', + lambda c: f'漫画{c}。', + lambda c: f'{c}的艺术照。', + lambda c: f'{c}的图形。', + lambda c: f'大{c}的照片。', + lambda c: f'黑白的{c}的照片。', + lambda c: f'{c}毛绒玩具。', + lambda c: f'一张{c}的深色照片。', + lambda c: f'{c}的摄影图。', + lambda c: f'{c}的涂鸦照。', + lambda c: f'玩具形状的{c}。', + lambda c: f'拍了{c}的照片。', + lambda c: f'酷酷的{c}的照片。', + lambda c: f'照片里的小{c}。', + lambda c: f'{c}的刺青。', + lambda c: f'{c}的可爱的照片。', + lambda c: f'一张{c}可爱的照片。', + lambda c: f'{c}可爱图片。', + lambda c: f'{c}酷炫图片。', + lambda c: f'一张{c}的酷炫的照片。', + lambda c: f'一张{c}的酷炫图片。', + lambda c: f'这是{c}。', + lambda c: f'{c}的好看照片。', + lambda c: f'一张{c}的好看的图片。', + lambda c: f'{c}的好看图片。', + lambda c: f'一种叫{c}的花的照片', + lambda c: f'一种叫{c}的食物的照片', + lambda c: f'{c}的卫星照片', +] diff --git a/mmpretrain/models/multimodal/flamingo/__init__.py b/mmpretrain/models/multimodal/flamingo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0bfd63b657f5f0f1517ad6d31bce2821cb372cd --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .adapter import FlamingoLMAdapter +from .flamingo import Flamingo + +__all__ = ['Flamingo', 'FlamingoLMAdapter'] diff --git a/mmpretrain/models/multimodal/flamingo/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/flamingo/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c30179ff0484f9b36c1ddf2fc8efd7aa796e999 Binary files /dev/null and b/mmpretrain/models/multimodal/flamingo/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/flamingo/__pycache__/adapter.cpython-38.pyc b/mmpretrain/models/multimodal/flamingo/__pycache__/adapter.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2013b9970315f0e1239cc21f5417aa35d7ba18e6 Binary files /dev/null and b/mmpretrain/models/multimodal/flamingo/__pycache__/adapter.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/flamingo/__pycache__/flamingo.cpython-38.pyc b/mmpretrain/models/multimodal/flamingo/__pycache__/flamingo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..341a350ae72591af63dbf15e6032f207b0056711 Binary files /dev/null and b/mmpretrain/models/multimodal/flamingo/__pycache__/flamingo.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/flamingo/__pycache__/modules.cpython-38.pyc b/mmpretrain/models/multimodal/flamingo/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43577d993ec9a2b3f799bca682496229040c3d7f Binary files /dev/null and b/mmpretrain/models/multimodal/flamingo/__pycache__/modules.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/flamingo/__pycache__/utils.cpython-38.pyc b/mmpretrain/models/multimodal/flamingo/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8adcee7fc10225f50e7ff2e12ec4a624f1b98b86 Binary files /dev/null and b/mmpretrain/models/multimodal/flamingo/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/flamingo/adapter.py b/mmpretrain/models/multimodal/flamingo/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..bef0e2f86bfbe81046bb25fa4b9915e4c4f9005a --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/adapter.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .modules import FlamingoLayer, GatedCrossAttentionBlock +from .utils import getattr_recursive, setattr_recursive + + +@MODELS.register_module() +class FlamingoLMAdapter: + """Mixin to add cross-attention layers to a language model.""" + + @classmethod + def extend_init( + cls, + base: object, + vis_hidden_size: int, + cross_attn_every_n_layers: int, + use_media_placement_augmentation: bool, + only_attend_previous: bool = False, + ): + """Initialize Flamingo by adding a new gated cross attn to the decoder. + + Store the media token id for computing the media locations. + + Args: + base (object): Base module could be any object that represent + a instance of language model. + vis_hidden_size: (int): Hidden size of vision embeddings. + cross_attn_every_n_layers: (int): Additional cross attn for + every n layers. + use_media_placement_augmentation: (bool): Whether to use media + placement augmentation. + """ + base.set_decoder_layers_attr_name('model.layers') + gated_cross_attn_layers = nn.ModuleList([ + GatedCrossAttentionBlock( + dim=base.config.hidden_size, dim_visual=vis_hidden_size) if + (layer_idx + 1) % cross_attn_every_n_layers == 0 else None + for layer_idx, _ in enumerate(base._get_decoder_layers()) + ]) + base._set_decoder_layers( + nn.ModuleList([ + FlamingoLayer(gated_cross_attn_layer, decoder_layer) + for gated_cross_attn_layer, decoder_layer in zip( + gated_cross_attn_layers, base._get_decoder_layers()) + ])) + base.use_media_placement_augmentation = use_media_placement_augmentation # noqa + base.initialized_flamingo = True + base.only_attend_previous = only_attend_previous + return base + + def set_decoder_layers_attr_name(self, decoder_layers_attr_name): + """Set decoder layers attribute name.""" + self.decoder_layers_attr_name = decoder_layers_attr_name + + def _get_decoder_layers(self): + """Get decoder layers according to attribute name.""" + return getattr_recursive(self, self.decoder_layers_attr_name) + + def _set_decoder_layers(self, value): + """Set decoder layers according to attribute name.""" + setattr_recursive(self, self.decoder_layers_attr_name, value) + + def forward(self, *input, **kwargs): + """Condition the Flamingo layers on the media locations before forward + function.""" + input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else input[0] + media_locations = input_ids == self.media_token_id + if self.only_attend_previous: + attend_previous = True + elif self.use_media_placement_augmentation: + attend_previous = (random.random() < 0.5) + else: + attend_previous = False + + for layer in self.get_decoder().layers: + layer.condition_media_locations(media_locations) + layer.condition_attend_previous(attend_previous) + + return super().forward( + *input, **kwargs) # Call the other parent's forward method + + def is_conditioned(self) -> bool: + """Check whether all decoder layers are already conditioned.""" + return all(layer.is_conditioned() + for layer in self._get_decoder_layers()) + + def clear_conditioned_layers(self): + """Clear all conditional layers.""" + for layer in self._get_decoder_layers(): + layer.condition_vis_x(None) + layer.condition_media_locations(None) + layer.condition_attend_previous(None) diff --git a/mmpretrain/models/multimodal/flamingo/flamingo.py b/mmpretrain/models/multimodal/flamingo/flamingo.py new file mode 100644 index 0000000000000000000000000000000000000000..abdd03328f4a22b0e4c2c37598d6e5517555994d --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/flamingo.py @@ -0,0 +1,322 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .modules import PerceiverResampler +from .utils import ExtendModule + + +@MODELS.register_module() +class Flamingo(BaseModel): + """The Open Flamingo model for multiple tasks. + + Args: + vision_encoder (dict): The config of the vision encoder. + lang_encoder (dict): The config of the language encoder. + tokenizer (dict): The tokenizer to encode the text. + task (int): The task to perform prediction. + zeroshot_prompt (str): Prompt used for zero-shot inference. + Defaults to 'Output:'. + shot_prompt_tmpl (str): Prompt used for few-shot inference. + Defaults to 'Output:{caption}<|endofchunk|>'. + final_prompt_tmpl (str): Final part of prompt used for inference. + Defaults to 'Output:'. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of [~`transformers.GenerationConfig`]. + Defaults to an empty dict. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + + support_tasks = {'caption', 'vqa'} + _no_split_modules = [ + 'TransformerEncoderLayer', 'PerceiverAttention', + 'GatedCrossAttentionBlock', 'FlamingoLayer' + ] + + def __init__( + self, + vision_encoder: dict, + lang_encoder: dict, + tokenizer: dict, + task: str = 'caption', + zeroshot_prompt: str = 'Output:', + shot_prompt_tmpl: str = 'Output:{caption}<|endofchunk|>', + final_prompt_tmpl: str = 'Output:', + generation_cfg: dict = dict(), + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if task not in self.support_tasks: + raise ValueError(f'Unsupported task {task}, please select ' + f'the task from {self.support_tasks}.') + self.task = task + + # init tokenizer + self.tokenizer = TOKENIZER.build(tokenizer) + # add Flamingo special tokens to the tokenizer + self.tokenizer.add_special_tokens( + {'additional_special_tokens': ['<|endofchunk|>', '']}) + self.tokenizer.bos_token_id = 1 + if self.tokenizer.pad_token is None: + # Issue: GPT models don't have a pad token, which we use to + # modify labels for the loss. + self.tokenizer.add_special_tokens({'pad_token': ''}) + + # Template to format the prompt input + self.zeroshot_prompt = zeroshot_prompt + self.shot_prompt_tmpl = shot_prompt_tmpl + self.final_prompt_tmpl = final_prompt_tmpl + + # init vision encoder related modules + vision_encoder_weight = vision_encoder.pop('pretrained', None) + self.vision_encoder = MODELS.build(vision_encoder) + if vision_encoder_weight is not None: + from mmengine.runner.checkpoint import load_checkpoint + load_checkpoint( + self.vision_encoder, + vision_encoder_weight, + map_location='cpu', + revise_keys=[(r'^backbone\.', '')], + ) + + self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims) + + # init language encoder related modules + self.lang_encoder = ExtendModule(**lang_encoder) + self.lang_encoder.resize_token_embeddings(len(self.tokenizer)) + self.lang_encoder.media_token_id = self.tokenizer.encode('')[-1] + + # other necessary parameters + self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1] + self.generation_cfg = { + 'num_beams': 1, + 'max_new_tokens': None, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 1.0, + 'no_repeat_ngram_size': 0, + 'prefix_allowed_tokens_fn': None, + 'length_penalty': 1.0, + 'num_return_sequences': 1, + 'do_sample': False, + 'early_stopping': False, + **generation_cfg, + } + + if hasattr(self, 'register_load_state_dict_post_hook'): + self.register_load_state_dict_post_hook(self._load_adapter_hook) + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + The method should accept only one mode "loss": + + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (torch.Tensor): The input image tensor with different ndim + according to the inputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_vision_feats(self, images: torch.Tensor) -> torch.Tensor: + """Extract vision features. + + Args: + images (torch.Tensor): For zero-shot, the input images tensor is + with shape (B, C, H, W), for few-shot, which is + (B, T_img, C, H, W) in general. Images in the same chunk + are collated along T_img. Video data is not supported yet. + + Returns: + torch.Tensor: Return extracted features. + """ + if images.ndim == 4: + # (B, C, H, W) -> (B, 1, C, H, W) for zero-shot. + images = images.unsqueeze(1) + b, T = images.shape[:2] + # b T c h w -> (b T) c h w + images = images.view(b * T, *images.shape[-3:]) + + with torch.no_grad(): + vision_feats = self.vision_encoder(images)[-1][:, 1:] + + # (b T F) v d -> b T F v d Only support F=1 here + vision_feats = vision_feats.view(b, T, 1, *vision_feats.shape[-2:]) + + vision_feats = self.perceiver(vision_feats) # reshapes to (b, T, n, d) + return vision_feats + + def predict(self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **generation_cfg): + """Predict generation results from a batch of inputs. + + Args: + images (torch.Tensor): For zero-shot, the input images tensor is + with shape (B, C, H, W), for few-shot, which is + (B, T_img, C, H, W) in general. Images in the same chunk + are collated along T_img. Video data is not supported yet. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **generation_cfg: Other keyword arguments accepted by the + ``generate`` method of :attr:`lang_encoder`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # generation_cfg in prediction should be dominant + generation_cfg = {**self.generation_cfg, **generation_cfg} + num_beams = generation_cfg['num_beams'] + + if num_beams > 1: + images = images.repeat_interleave(num_beams, dim=0) + + # extra vision feats and set as language condition feats + vision_x = self.extract_vision_feats(images) + for layer in self.lang_encoder._get_decoder_layers(): + layer.condition_vis_x(vision_x) + + input_text = self.preprocess_text(data_samples, device=images.device) + + outputs = self.lang_encoder.generate( + input_text.input_ids, + attention_mask=input_text.attention_mask, + eos_token_id=self.eoc_token_id, + **generation_cfg) + + # clear conditioned layers for language models + self.lang_encoder.clear_conditioned_layers() + + # remove prefix + outputs = outputs[:, len(input_text.input_ids[0]):] + + return self.post_process(outputs, data_samples) + + def preprocess_text(self, data_samples: List[DataSample], + device: torch.device) -> List[DataSample]: + """Preprocess text in advance before fed into language model. + + Args: + data_samples (List[DataSample]): The annotation + data of every samples. Defaults to None. + device (torch.device): Device for text to put on. + + Returns: + List[DataSample]: Return list of data samples. + """ + prompts = [] + for sample in data_samples: + if 'shots' in sample: + # few-shot + shot_prompt = ''.join([ + self.shot_prompt_tmpl.format(**shot) + for shot in sample.get('shots') + ]) + else: + # zero-shot + shot_prompt = self.zeroshot_prompt + + # add final prompt + final_prompt = self.final_prompt_tmpl.format(**sample.to_dict()) + prompts.append(shot_prompt + final_prompt) + + self.tokenizer.padding_side = 'left' + input_text = self.tokenizer( + prompts, + padding='longest', + truncation=True, + return_tensors='pt', + max_length=2000, + ).to(device) + return input_text + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. + """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = re.split('Output', output, + 1)[0].replace('"', '') + elif self.task == 'vqa': + data_sample.pred_answer = re.split('Question|Answer', output, + 1)[0] + + return data_samples + + @staticmethod + def _load_adapter_hook(module, incompatible_keys): + """Avoid warning missing keys except adapter keys.""" + adapter_patterns = [ + '^perceiver', + 'lang_encoder.*embed_tokens', + 'lang_encoder.*gated_cross_attn_layers', + 'lang_encoder.*rotary_emb', + ] + for key in list(incompatible_keys.missing_keys): + if not any(re.match(pattern, key) for pattern in adapter_patterns): + incompatible_keys.missing_keys.remove(key) + + for key in list(incompatible_keys.unexpected_keys): + if 'position_ids' in key: + incompatible_keys.unexpected_keys.remove(key) + if 'lang_encoder.gated_cross_attn_layers' in key: + incompatible_keys.unexpected_keys.remove(key) diff --git a/mmpretrain/models/multimodal/flamingo/modules.py b/mmpretrain/models/multimodal/flamingo/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..730c61b68a8d0fb799b7985636f09b6484ef99c2 --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/modules.py @@ -0,0 +1,398 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Taken from https://github.com/lucidrains/flamingo-pytorch.""" + +from typing import Optional + +import torch +from einops import rearrange, repeat +from torch import einsum, nn + + +def FeedForward(dim, mult: int = 4): + """Feedforward layers. + + Args: + mult (int): Layer expansion muliplier. Defaults to 4. + """ + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + """Perceiver attetion layers. + + Args: + dim (int): Input dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + """ + + def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x: torch.Tensor, latents: torch.Tensor): + """Forward function. + + Args: + x (torch.Tensor): image features of shape (b, T, n1, D). + latent (torch.Tensor): latent features of shape (b, T, n2, D). + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q = rearrange(q, 'b t n (h d) -> b h t n d', h=h) + k = rearrange(k, 'b t n (h d) -> b h t n d', h=h) + v = rearrange(v, 'b t n (h d) -> b h t n d', h=h) + q = q * self.scale + + # attention + sim = einsum('... i d, ... j d -> ... i j', q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum('... i j, ... j d -> ... i d', attn, v) + out = rearrange(out, 'b h t n d -> b t n (h d)', h=h) + return self.to_out(out) + + +class PerceiverResampler(nn.Module): + """Perceiver resampler layers. + + Args: + dim (int): Input dimensions. + depth (int): Depth of resampler. Defaults to 6. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + num_latents (int): Number of latents. Defaults to 64. + max_num_media (int, optional): Max number of media. + Defaults to None. + max_num_frames (int, optional): Max number of frames. + Defaults to None. + ff_mult (int): Feed forward multiplier. Defaults to 4. + """ + + def __init__( + self, + *, + dim: int, + depth: int = 6, + dim_head: int = 64, + heads: int = 8, + num_latents: int = 64, + max_num_media: Optional[int] = None, + max_num_frames: Optional[int] = None, + ff_mult: int = 4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = ( + nn.Parameter(torch.randn(max_num_frames, dim)) + if max_num_frames is not None else None) + self.media_time_embs = ( + nn.Parameter(torch.randn(max_num_media, 1, dim)) + if max_num_media is not None else None) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention( + dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): image features of shape (b, T, F, v, D) + + Returns: + torch.Tensor: shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if self.frame_embs is not None: + frame_embs = repeat( + self.frame_embs[:F], 'F d -> b T F v d', b=b, T=T, v=v) + x = x + frame_embs + x = rearrange(x, 'b T F v d -> b T (F v) d' + ) # flatten the frame and spatial dimensions + if self.media_time_embs is not None: + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, 'n d -> b T n d', b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +class MaskedCrossAttention(nn.Module): + """Masked cross attention layers. + + Args: + dim (int): Input text feature dimensions. + dim_visual (int): Input visual feature dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + only_attend_immediate_media (bool): Whether attend immediate media. + Defaults to True. + """ + + def __init__( + self, + *, + dim: int, + dim_visual: int, + dim_head: int = 64, + heads: int = 8, + only_attend_immediate_media: bool = True, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether for text to only attend to immediate preceding image + # or all previous images + self.only_attend_immediate_media = only_attend_immediate_media + + def forward(self, + x: torch.Tensor, + media: torch.Tensor, + media_locations: Optional[torch.Tensor] = None, + attend_previous: bool = True): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): text features of shape (B, T_txt, D_txt). + media (torch.Tensor): image features of shape + (B, T_img, n, D_img) where n is the dim of the latents. + media_locations (torch.Tensor, optional): boolean mask identifying + the media tokens in x of shape (B, T_txt). Defaults to None. + attend_previous (bool): If false, ignores immediately preceding + image and starts attending when following image. + Defaults to True. + """ + _, T_img, n = media.shape[:3] + h = self.heads + + x = self.norm(x) + + q = self.to_q(x) + media = rearrange(media, 'b t n d -> b (t n) d') + + k, v = self.to_kv(media).chunk(2, dim=-1) + q = rearrange(q, 'b n (h d) -> b h n d', h=h) + k = rearrange(k, 'b n (h d) -> b h n d', h=h) + v = rearrange(v, 'b n (h d) -> b h n d', h=h) + + q = q * self.scale + + sim = einsum('... i d, ... j d -> ... i j', q, k) + + if media_locations is not None: + # at each boolean of True, increment the time counter + # (relative to media time) + text_time = media_locations.cumsum(dim=-1) + media_time = torch.arange(T_img, device=x.device) + 1 + + if not attend_previous: + text_time[~media_locations] += 1 + # make sure max is still the number of images in the sequence + text_time[text_time > repeat( + torch.count_nonzero(media_locations, dim=1), + 'b -> b i', + i=text_time.shape[1], + )] = 0 + + # text time must equal media time if only attending to most + # immediate image otherwise, as long as text time is greater than + # media time (if attending to all previous images / media) + mask_op = torch.eq if self.only_attend_immediate_media else torch.ge # noqa + + text_to_media_mask = mask_op( + rearrange(text_time, 'b i -> b 1 i 1'), + repeat(media_time, 'j -> 1 1 1 (j n)', n=n), + ) + sim = sim.masked_fill(~text_to_media_mask, + -torch.finfo(sim.dtype).max) + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + if media_locations is not None and self.only_attend_immediate_media: + # any text without a preceding media needs to have + # attention zeroed out + text_without_media_mask = text_time == 0 + text_without_media_mask = rearrange(text_without_media_mask, + 'b i -> b 1 i 1') + attn = attn.masked_fill(text_without_media_mask, 0.0) + + out = einsum('... i j, ... j d -> ... i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class GatedCrossAttentionBlock(nn.Module): + """Gated cross attention layers. + + Args: + dim (int): Input text feature dimensions. + dim_visual (int): Input visual feature dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + ff_mult (int): Feed forward multiplier. Defaults to 4. + only_attend_immediate_media (bool): Whether attend immediate media. + Defaults to True. + """ + + def __init__( + self, + *, + dim: int, + dim_visual: int, + dim_head: int = 64, + heads: int = 8, + ff_mult: int = 4, + only_attend_immediate_media: bool = True, + ): + super().__init__() + self.attn = MaskedCrossAttention( + dim=dim, + dim_visual=dim_visual, + dim_head=dim_head, + heads=heads, + only_attend_immediate_media=only_attend_immediate_media, + ) + self.attn_gate = nn.Parameter(torch.tensor([0.0])) + + self.ff = FeedForward(dim, mult=ff_mult) + self.ff_gate = nn.Parameter(torch.tensor([0.0])) + + def forward(self, + x: torch.Tensor, + media: torch.Tensor, + media_locations: Optional[torch.Tensor] = None, + attend_previous: bool = True): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): text features of shape (B, T_txt, D_txt). + media (torch.Tensor): image features of shape + (B, T_img, n, D_img) where n is the dim of the latents. + media_locations (torch.Tensor, optional): boolean mask identifying + the media tokens in x of shape (B, T_txt). Defaults to None. + attend_previous (bool): If false, ignores immediately preceding + image and starts attending when following image. + Defaults to True. + """ + x = ( + self.attn( + x, + media, + media_locations=media_locations, + attend_previous=attend_previous, + ) * self.attn_gate.tanh() + x) + x = self.ff(x) * self.ff_gate.tanh() + x + + return x + + +class FlamingoLayer(nn.Module): + """Faminogo layers. + + Args: + gated_cross_attn_layer (nn.Module): Gated cross attention layer. + decoder_layer (nn.Module): Decoder layer. + """ + + def __init__(self, gated_cross_attn_layer: nn.Module, + decoder_layer: nn.Module): + super().__init__() + self.gated_cross_attn_layer = gated_cross_attn_layer + self.decoder_layer = decoder_layer + self.vis_x = None + self.media_locations = None + + def is_conditioned(self) -> bool: + """Check whether the layer is conditioned.""" + return self.vis_x is not None + + def condition_vis_x(self, vis_x): + """Set condition vision features.""" + self.vis_x = vis_x + + def condition_media_locations(self, media_locations): + """Set condition media locations.""" + self.media_locations = media_locations + + def condition_attend_previous(self, attend_previous): + """Set attend previous.""" + self.attend_previous = attend_previous + + def forward( + self, + lang_x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **decoder_layer_kwargs, + ): + """Forward function. + + Args: + lang_x (torch.Tensor): language inputs. + attention_mask (torch.Tensor, optional): text attention mask. + Defaults to None. + **decoder_layer_kwargs: Other decoder layer keyword arguments. + """ + if self.gated_cross_attn_layer is None: + return self.decoder_layer( + lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) + + if self.vis_x is None: + raise ValueError('vis_x must be conditioned before forward pass') + + if self.media_locations is None: + raise ValueError( + 'media_locations must be conditioned before forward pass') + + lang_x = self.gated_cross_attn_layer( + lang_x, + self.vis_x, + media_locations=self.media_locations, + attend_previous=self.attend_previous, + ) + lang_x = self.decoder_layer( + lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) + return lang_x diff --git a/mmpretrain/models/multimodal/flamingo/utils.py b/mmpretrain/models/multimodal/flamingo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1077e145a7daeeff1c769d837ec9c5aac0cf3d93 --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/utils.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Type + +from mmpretrain.registry import MODELS + + +class ExtendModule: + """Combine the base language model with adapter. This module will create a + instance from base with extended functions in adapter. + + Args: + base (object): Base module could be any object that represent + a instance of language model or a dict that can build the + base module. + adapter: (dict): Dict to build the adapter. + """ + + def __new__(cls, base: object, adapter: dict): + + if isinstance(base, dict): + base = MODELS.build(base) + + adapter_module = MODELS.get(adapter.pop('type')) + cls.extend_instance(base, adapter_module) + return adapter_module.extend_init(base, **adapter) + + @classmethod + def extend_instance(cls, base: object, mixin: Type[Any]): + """Apply mixins to a class instance after creation. + + Args: + base (object): Base module instance. + mixin: (Type[Any]): Adapter class type to mixin. + """ + base_cls = base.__class__ + base_cls_name = base.__class__.__name__ + base.__class__ = type( + base_cls_name, (mixin, base_cls), + {}) # mixin needs to go first for our forward() logic to work + + +def getattr_recursive(obj, att): + """ + Return nested attribute of obj + Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c + """ + if att == '': + return obj + i = att.find('.') + if i < 0: + return getattr(obj, att) + else: + return getattr_recursive(getattr(obj, att[:i]), att[i + 1:]) + + +def setattr_recursive(obj, att, val): + """ + Set nested attribute of obj + Example: setattr_recursive(obj, 'a.b.c', val) + is equivalent to obj.a.b.c = val + """ + if '.' in att: + obj = getattr_recursive(obj, '.'.join(att.split('.')[:-1])) + setattr(obj, att.split('.')[-1], val) diff --git a/mmpretrain/models/multimodal/llava/__init__.py b/mmpretrain/models/multimodal/llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aef10d34d46fc3974744881c814068ae7d6f9357 --- /dev/null +++ b/mmpretrain/models/multimodal/llava/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .llava import Llava +from .modules import LlavaLlamaForCausalLM + +__all__ = ['Llava', 'LlavaLlamaForCausalLM'] diff --git a/mmpretrain/models/multimodal/llava/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/llava/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e832f9f80a121c01ad7dd0770d2423d81e3e0a9c Binary files /dev/null and b/mmpretrain/models/multimodal/llava/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/llava/__pycache__/llava.cpython-38.pyc b/mmpretrain/models/multimodal/llava/__pycache__/llava.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41a5b1d0167bcaede83f1ea1be7a2aa154762679 Binary files /dev/null and b/mmpretrain/models/multimodal/llava/__pycache__/llava.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/llava/__pycache__/modules.cpython-38.pyc b/mmpretrain/models/multimodal/llava/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faa9ec2f6b155cf332f9dcf0adece217aa641cd4 Binary files /dev/null and b/mmpretrain/models/multimodal/llava/__pycache__/modules.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py new file mode 100644 index 0000000000000000000000000000000000000000..1c300fdcd05917ade4ee638a27e0ec79afbb4e63 --- /dev/null +++ b/mmpretrain/models/multimodal/llava/llava.py @@ -0,0 +1,257 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from ...utils import no_load_hf_pretrained_model +from .modules import LlavaLlamaForCausalLM + + +@MODELS.register_module() +class Llava(BaseModel): + """The LLaVA model for multiple tasks. + + Args: + vision_encoder (dict): The config of the vision encoder. + lang_encoder (dict): The config of the language encoder. + tokenizer (dict): The tokenizer to encode the text. + prompt_tmpl (str): Prompt template for inference. + task (int): The task to perform prediction. + use_im_start_end (bool): Whether to use the im_start and im_end tokens + mm_vision_select_layer (int): The index from vision encoder output. + Defaults to -1. + use_mm_proj (bool): Whether to enable multi-modal projection. + Defaults to True. + load_lang_pretrained (bool): Whether to load the pretrained model of + language encoder. Defaults to False. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of [~`transformers.GenerationConfig`]. + Defaults to an empty dict. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + + support_tasks = {'caption', 'vqa'} + im_patch_token = '' + im_start_token = '' + im_end_token = '' + + def __init__(self, + vision_encoder: dict, + lang_encoder: dict, + tokenizer: dict, + mm_hidden_size: int, + prompt_tmpl: str, + task: str = 'caption', + use_im_start_end: bool = False, + mm_vision_select_layer: int = -1, + use_mm_proj: bool = True, + generation_cfg: dict = dict(), + load_lang_pretrained: bool = False, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if task not in self.support_tasks: + raise ValueError(f'Unsupported task {task}, please select ' + f'the task from {self.support_tasks}.') + self.task = task + + # init tokenizer + self.tokenizer = TOKENIZER.build(tokenizer) + # add Llava special tokens to the tokenizer + self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True) + if use_im_start_end: + self.tokenizer.add_tokens([self.im_start_token, self.im_end_token], + special_tokens=True) + + # Template to format the prompt input + self.prompt_tmpl = prompt_tmpl + + # init vision encoder related modules + vision_encoder_weight = vision_encoder.pop('pretrained', None) + vision_encoder = MODELS.build(vision_encoder) + if vision_encoder_weight is not None: + from mmengine.runner.checkpoint import load_checkpoint + load_checkpoint( + vision_encoder, + vision_encoder_weight, + map_location='cpu', + revise_keys=[(r'^backbone\.', '')], + ) + + # init language encoder related modules + if load_lang_pretrained: + lang_encoder = MODELS.build(lang_encoder) + else: + with no_load_hf_pretrained_model(): + lang_encoder = MODELS.build(lang_encoder) + lang_encoder.resize_token_embeddings(len(self.tokenizer)) + + self.model = LlavaLlamaForCausalLM( + vision_encoder=vision_encoder, + lang_encoder=lang_encoder, + mm_hidden_size=mm_hidden_size, + use_mm_proj=use_mm_proj, + use_im_start_end=use_im_start_end, + im_start_token=self.tokenizer.convert_tokens_to_ids( + self.im_start_token), + im_end_token=self.tokenizer.convert_tokens_to_ids( + self.im_end_token), + im_patch_token=self.tokenizer.convert_tokens_to_ids( + self.im_patch_token), + mm_vision_select_layer=mm_vision_select_layer) + + self.generation_cfg = generation_cfg + + if hasattr(self, 'register_load_state_dict_post_hook'): + self.register_load_state_dict_post_hook(self._load_ckpt_hook) + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (torch.Tensor): The input image tensor with different ndim + according to the inputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + + if mode == 'predict': + return self.predict(images, data_samples) + elif mode == 'loss': + raise NotImplementedError + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **generation_cfg): + """Predict generation results from a batch of inputs. + + Args: + images (torch.Tensor): For zero-shot, the input images tensor is + with shape (B, C, H, W), for few-shot, which is + (B, T_img, C, H, W) in general. Images in the same chunk + are collated along T_img. Video data is not supported yet. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **generation_cfg: Other keyword arguments accepted by the + ``generate`` method of :attr:`lang_encoder`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # generation_cfg in prediction should be dominant + generation_cfg = {**self.generation_cfg, **generation_cfg} + + input_text = self.preprocess_text(data_samples, device=images.device) + + outputs = self.model.generate( + input_text.input_ids, + attention_mask=input_text.attention_mask, + eos_token_id=self.tokenizer.eos_token_id, + images=images, + **generation_cfg) + + # remove prefix + outputs = outputs[:, len(input_text.input_ids[0]):] + + return self.post_process(outputs, data_samples) + + def preprocess_text(self, data_samples: List[DataSample], + device: torch.device) -> List[DataSample]: + """Preprocess text in advance before fed into language model. + + Args: + data_samples (List[DataSample]): The annotation + data of every samples. Defaults to None. + device (torch.device): Device for text to put on. + + Returns: + List[DataSample]: Return list of data samples. + """ + prompts = [] + for sample in data_samples: + final_prompt = self.prompt_tmpl.format(**sample.to_dict()) + prompts.append(final_prompt) + + self.tokenizer.padding_side = 'left' + input_text = self.tokenizer( + prompts, + padding='longest', + truncation=True, + return_tensors='pt', + max_length=2000, + ).to(device) + return input_text + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. + """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = output + elif self.task == 'vqa': + data_sample.pred_answer = output + + return data_samples + + @staticmethod + def _load_ckpt_hook(module, incompatible_keys): + """Avoid warning missing keys except lang_encoder keys.""" + for key in list(incompatible_keys.missing_keys): + if re.match('model.vision_tower', key): + incompatible_keys.missing_keys.remove(key) diff --git a/mmpretrain/models/multimodal/llava/modules.py b/mmpretrain/models/multimodal/llava/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..afa6eefadcbd73f630d8c842c80b83f229216c97 --- /dev/null +++ b/mmpretrain/models/multimodal/llava/modules.py @@ -0,0 +1,238 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel + +DEFAULT_IMAGE_TOKEN = '' +DEFAULT_IMAGE_PATCH_TOKEN = '' +DEFAULT_IM_START_TOKEN = '' +DEFAULT_IM_END_TOKEN = '' + + +class LlavaLlamaForCausalLM(PreTrainedModel): + + def __init__(self, + vision_encoder, + lang_encoder, + mm_hidden_size, + use_im_start_end=True, + use_mm_proj=True, + im_start_token: Optional[int] = None, + im_end_token: Optional[int] = None, + im_patch_token: Optional[int] = None, + mm_vision_select_layer: int = -1): + super().__init__(lang_encoder.config) + self.vision_tower = vision_encoder + self.lang_encoder = lang_encoder + + self.use_im_start_end = use_im_start_end + self.im_start_token = im_start_token + self.im_end_token = im_end_token + self.im_patch_token = im_patch_token + self.mm_hidden_size = mm_hidden_size + self.mm_vision_select_layer = mm_vision_select_layer + self.lang_hidden_size = lang_encoder.config.hidden_size + + if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'): + mm_projector = nn.Linear(self.mm_hidden_size, + self.lang_hidden_size) + self.lang_encoder.model.add_module('mm_projector', mm_projector) + elif not use_mm_proj: + self.lang_encoder.model.add_module('mm_projector', nn.Identity()) + + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # decoder outputs consists of + # (dec_features, layer_state, dec_hidden, dec_attn) + if inputs_embeds is None: + inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids) + + inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds, + images) + + return self.lang_encoder( + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use + # them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + 'images': kwargs.get('images', None), + }) + return model_inputs + + def forward_vision_tower( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + images: Union[torch.FloatTensor, list, None] = None, + ): + if self.use_im_start_end: + assert self.im_start_token is not None + assert self.im_end_token is not None + if images is not None: + assert self.im_patch_token is not None + + if self.vision_tower is None or images is None or ( + input_ids.shape[1] == 1 and not self.training): + return inputs_embeds + + with torch.no_grad(): + if isinstance(images, (list, tuple)): + # variable length images + image_features = [] + for image in images: + feats = self.vision_tower(image.unsqueeze(0)) + image_feature = feats[self.mm_vision_select_layer][:, 1:] + image_features.append(image_feature) + else: + feats = self.vision_tower(images) + image_features = feats[self.mm_vision_select_layer][:, 1:] + + mm_projector = self.lang_encoder.model.mm_projector + if isinstance(images, (list, tuple)): + image_features = [ + mm_projector(image_feature)[0] + for image_feature in image_features + ] + else: + image_features = mm_projector(image_features) + + dummy_image_features = torch.zeros( + 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_image_features = mm_projector(dummy_image_features) + + new_input_embeds = [] + cur_image_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): + if (cur_input_ids != self.im_patch_token).all(): + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = cur_input_embeds + ( + 0. * dummy_image_features).sum() + new_input_embeds.append(cur_input_embeds) + cur_image_idx += 1 + continue + if self.use_im_start_end: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == self.im_start_token).sum() != ( + cur_input_ids == self.im_end_token).sum(): + raise ValueError('The number of image start tokens and ' + 'image end tokens should be the same.') + image_start_tokens = torch.where( + cur_input_ids == self.im_start_token)[0] + for image_start_token_pos in image_start_tokens: + cur_image_features = image_features[cur_image_idx].to( + device=cur_input_embeds.device) + num_patches = cur_image_features.shape[0] + if cur_input_ids[image_start_token_pos + num_patches + + 1] != self.im_end_token: + raise ValueError('The image end token should follow ' + 'the image start token.') + cur_new_input_embeds = torch.cat( + (cur_input_embeds[:image_start_token_pos + 1], + cur_image_features, + cur_input_embeds[image_start_token_pos + num_patches + + 1:]), + dim=0) + cur_image_idx += 1 + new_input_embeds.append(cur_new_input_embeds) + else: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == self.im_patch_token).sum() != num_patches: + print(f'Debug: num_patches: {num_patches}') + raise ValueError( + 'The number of image patch tokens should ' + 'be the same as the number of image patches.') + masked_indices = torch.where( + cur_input_ids == self.im_patch_token)[0] + mask_index_start = masked_indices[0] + if (masked_indices != torch.arange( + mask_index_start, + mask_index_start + num_patches, + device=masked_indices.device, + dtype=masked_indices.dtype)).any(): + raise ValueError( + 'The image patch tokens should be consecutive.') + cur_new_input_embeds = torch.cat( + (cur_input_embeds[:mask_index_start], cur_image_features, + cur_input_embeds[mask_index_start + num_patches:]), + dim=0) + new_input_embeds.append(cur_new_input_embeds) + cur_image_idx += 1 + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return inputs_embeds + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/minigpt4/__init__.py b/mmpretrain/models/multimodal/minigpt4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5358bb1377ee6da7d848c06f3a249493645cdbf7 --- /dev/null +++ b/mmpretrain/models/multimodal/minigpt4/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .minigpt4 import MiniGPT4 + +__all__ = ['MiniGPT4'] diff --git a/mmpretrain/models/multimodal/minigpt4/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/minigpt4/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cf8fcfd54bb7d1b0e3faa5be5b5bae8829d071d Binary files /dev/null and b/mmpretrain/models/multimodal/minigpt4/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/minigpt4/__pycache__/minigpt4.cpython-38.pyc b/mmpretrain/models/multimodal/minigpt4/__pycache__/minigpt4.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec0ba76f589d848b3734b4b9c234ad81af993785 Binary files /dev/null and b/mmpretrain/models/multimodal/minigpt4/__pycache__/minigpt4.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py new file mode 100644 index 0000000000000000000000000000000000000000..d23203603ec7767cb5ca494865f588f9bcecff0f --- /dev/null +++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py @@ -0,0 +1,381 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +import re +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.logging import MMLogger +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class MiniGPT4(BaseModel): + """The multi-modality model of MiniGPT-4. + + The implementation of `MiniGPT-4 `_. + Modified from https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py + + Args: + vision_encoder (dict): The config for vision encoder. + q_former_model (dict): The config for Qformer. + lang_encoder (dict): The config for language model. + tokenizer (dict): The config for tokenizer. + task (str): To define the task, which control the processing of text. + Defaults to 'caption'. + freeze_vit (bool): Freeze the training of ViT. Defaults to True. + freeze_q_former (bool): Freeze the training of Qformer. Defaults to + True. + num_query_token (int): Number of query tokens of Qformer. Defaults to + 32. + prompt_template (str): Prompt template of the model. Defaults to + '###Human: {} ###Assistant: '. + raw_prompts (list): Prompts for training. Defaults to None. + max_txt_len (int): Max token length while doing tokenization. Defaults + to 32. + end_sym (str): Ended symbol of the sequence. Defaults to '\n'. + generation_cfg (dict): The config of text generation. Defaults to + dict(). + data_preprocessor (:obj:`BaseDataPreprocessor`): Used for + pre-processing data sampled by dataloader to the format accepted by + :meth:`forward`. Defaults to None. + init_cfg (dict): Initialization config dict. Defaults to None. + """ # noqa + + def __init__(self, + vision_encoder: dict, + q_former_model: dict, + lang_encoder: dict, + tokenizer: dict, + task: str = 'caption', + freeze_vit: bool = True, + freeze_q_former: bool = True, + num_query_token: int = 32, + prompt_template: str = '###Human: {} ###Assistant: ', + raw_prompts: Optional[list] = None, + max_txt_len: int = 32, + end_sym: str = '\n', + generation_cfg: dict = dict(), + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.task = task + logger = MMLogger.get_current_instance() + + # build vision model + vision_encoder_weight = vision_encoder.pop('pretrained', None) + self.vision_encoder = MODELS.build(vision_encoder) + self.ln_vision = nn.LayerNorm(self.vision_encoder.embed_dims) + + if vision_encoder_weight is not None: + from mmengine.runner.checkpoint import load_checkpoint + load_checkpoint(self.vision_encoder, vision_encoder_weight) + if freeze_vit: + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + else: + logger.warning('Please check `frozen_stages` in the dict of' + '`vision_encoder`. Also set it to be -1 if do not' + 'freeze ViT.') + + # build Qformer + q_former_model_weight = q_former_model.pop('pretrained', None) + self.q_former = MODELS.build(q_former_model) + self.q_former.cls = None + self.q_former.bert.embeddings.word_embeddings = None + self.q_former.bert.embeddings.position_embeddings = None + for layer in self.q_former.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + self.query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, self.q_former.config.hidden_size)) + self.query_tokens.data.normal_( + mean=0.0, std=self.q_former.config.initializer_range) + + if q_former_model_weight is not None: + from mmengine.runner.checkpoint import CheckpointLoader + state_dict = CheckpointLoader.load_checkpoint( + q_former_model_weight)['state_dict'] + self.load_state_dict(state_dict, strict=False) + + if freeze_q_former: + for name, param in self.q_former.named_parameters(): + param.requires_grad = False + self.q_former.eval() + self.query_tokens.requires_grad = False + + # build language model + self.llama_tokenizer = TOKENIZER.build(tokenizer) + self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token + + self.llama_model = MODELS.build(lang_encoder) + for name, param in self.llama_model.named_parameters(): + param.requires_grad = False + + # build linear projection layer + self.llama_proj = nn.Linear(self.q_former.config.hidden_size, + self.llama_model.config.hidden_size) + self.max_txt_len = max_txt_len + self.end_sym = end_sym + self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1] + + # set prompts + if raw_prompts is not None: + filted_prompts = [ + raw_prompt for raw_prompt in raw_prompts + if '' in raw_prompt + ] + self.prompt_list = [ + prompt_template.format(p) for p in filted_prompts + ] + else: + self.prompt_list = [] + + # update generation configs + self.generation_cfg = dict( + max_new_tokens=300, + num_beams=1, + do_sample=True, + min_length=1, + top_p=0.9, + repetition_penalty=1.0, + length_penalty=1.0, + temperature=1.0, + **generation_cfg) + + if hasattr(self, 'register_load_state_dict_post_hook'): + self.register_load_state_dict_post_hook(self._load_llama_proj_hook) + + def encode_img(self, + images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """The function to encode the images.""" + device = images.device + x = self.vision_encoder(images)[0] + image_embeds = self.ln_vision(x).to(device) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.q_former.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + atts_llama = torch.ones( + inputs_llama.size()[:-1], dtype=torch.long).to(images.device) + return inputs_llama, atts_llama + + def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor, + prompt: str) -> Tuple[torch.Tensor, torch.Tensor]: + """The function to wrap the image and prompt. + + Currently, the function only supports applying one prompt to all input + images in the one batch. + + Args: + img_embeds (torch.Tensor): The embedding of the input images. + atts_img (torch.Tensor): Attention map of the image embeddings. + prompt (str): The prompt of the batch data. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map. + """ + if prompt: + batch_size = img_embeds.shape[0] + p_before, p_after = prompt.split('') + p_before_tokens = self.llama_tokenizer( + p_before, return_tensors='pt', + add_special_tokens=False).to(img_embeds.device) + p_after_tokens = self.llama_tokenizer( + p_after, return_tensors='pt', + add_special_tokens=False).to(img_embeds.device) + p_before_embeds = self.llama_model.model.embed_tokens( + p_before_tokens.input_ids).expand(batch_size, -1, -1) + p_after_embeds = self.llama_model.model.embed_tokens( + p_after_tokens.input_ids).expand(batch_size, -1, -1) + wrapped_img_embeds = torch.cat( + [p_before_embeds, img_embeds, p_after_embeds], dim=1) + wrapped_atts_img = atts_img[:, :1].expand( + -1, wrapped_img_embeds.shape[1]) + return wrapped_img_embeds, wrapped_atts_img + else: + return img_embeds, atts_img + + def loss(self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None) -> dict: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + img_embeds, atts_img = self.encode_img(images) + + if self.task == 'caption' and self.prompt_list: + prompt = random.choice(self.prompt_list) + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, + prompt) + + self.llama_tokenizer.padding_side = 'right' + + text = [t + self.end_sym for t in data_samples['text_input']] + + to_regress_tokens = self.llama_tokenizer( + text, + return_tensors='pt', + padding='longest', + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False).to(images.device) + + targets = to_regress_tokens.input_ids.masked_fill( + to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, + -100) + + empty_targets = ( + torch.ones([atts_img.shape[0], atts_img.shape[1] + 1], + dtype=torch.long).to(images.device).fill_( + -100) # plus one for bos + ) + targets = torch.cat([empty_targets, targets], dim=1) + + batch_size = img_embeds.shape[0] + bos = torch.ones([batch_size, 1], + dtype=to_regress_tokens.input_ids.dtype, + device=to_regress_tokens.input_ids.device + ) * self.llama_tokenizer.bos_token_id + bos_embeds = self.llama_model.model.embed_tokens(bos) + atts_bos = atts_img[:, :1] + + to_regress_embeds = self.llama_model.model.embed_tokens( + to_regress_tokens.input_ids) + inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], + dim=1) + attention_mask = torch.cat( + [atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1) + + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + return dict(loss=loss) + + def predict( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None + ) -> List[DataSample]: + + with torch.no_grad(): + img_embeds, atts_img = self.encode_img(images) + + if self.task == 'caption' and self.prompt_list: + prompt = random.choice(self.prompt_list) + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, + prompt) + + batch_size = img_embeds.shape[0] + bos = torch.ones( + [batch_size, 1], dtype=torch.long, + device=img_embeds.device) * self.llama_tokenizer.bos_token_id + bos_embeds = self.llama_model.model.embed_tokens(bos) + inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1) + + outputs = self.llama_model.generate( + inputs_embeds=inputs_embeds, + eos_token_id=self.end_token_id, + **self.generation_cfg) + + return self.post_process(outputs, data_samples) + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. + """ + outputs = self.llama_tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + if self.task == 'caption': + output = output.split('###')[0] + output = output.split('Assistant:')[-1].strip() + data_sample.pred_caption = output + else: + # raw output + data_sample.pred_output = output + return data_samples + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + @staticmethod + def _load_llama_proj_hook(module, incompatible_keys): + """Avoid warning missing keys except LLaMA projection keys.""" + proj_patterns = [ + 'vision_encoder.*', + 'ln_vision.*', + 'q_former.*', + 'query_tokens', + 'llama_model.*', + ] + for key in list(incompatible_keys.missing_keys): + if any(re.match(pattern, key) for pattern in proj_patterns): + incompatible_keys.missing_keys.remove(key) diff --git a/mmpretrain/models/multimodal/ofa/__init__.py b/mmpretrain/models/multimodal/ofa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb3f45f09b757304bfca3de2a94d217ff78d8d4 --- /dev/null +++ b/mmpretrain/models/multimodal/ofa/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ofa import OFA +from .ofa_modules import OFADecoder, OFAEncoder, OFAEncoderDecoder + +__all__ = ['OFAEncoderDecoder', 'OFA', 'OFAEncoder', 'OFADecoder'] diff --git a/mmpretrain/models/multimodal/ofa/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/ofa/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9615002c111a583132779cb8569bbac89fb0411 Binary files /dev/null and b/mmpretrain/models/multimodal/ofa/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/ofa/__pycache__/ofa.cpython-38.pyc b/mmpretrain/models/multimodal/ofa/__pycache__/ofa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a450dfa5b7ff46b5da923a4fc70b47a62b9fa5d Binary files /dev/null and b/mmpretrain/models/multimodal/ofa/__pycache__/ofa.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/ofa/__pycache__/ofa_modules.cpython-38.pyc b/mmpretrain/models/multimodal/ofa/__pycache__/ofa_modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bd463d34726fd886ce86262cbab7deee3c49551 Binary files /dev/null and b/mmpretrain/models/multimodal/ofa/__pycache__/ofa_modules.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/ofa/ofa.py b/mmpretrain/models/multimodal/ofa/ofa.py new file mode 100644 index 0000000000000000000000000000000000000000..e15787a60d66ac56308b320cdd73a7703a2a29bc --- /dev/null +++ b/mmpretrain/models/multimodal/ofa/ofa.py @@ -0,0 +1,320 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import string +from collections import defaultdict +from functools import partial +from typing import Optional, Union + +import mmengine +import torch +from mmengine.model import BaseModel + +from mmpretrain.datasets import CleanCaption +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .ofa_modules import OFAEncoderDecoder + + +class TreeNode(): + + def __init__(self): + self.child = defaultdict(TreeNode) + + +class Trie: + + def __init__(self, eos): + self.root = TreeNode() + self.eos = eos + + def insert(self, word): + cur = self.root + for c in word: + cur = cur.child[c] + + def get_next_layer(self, word): + cur = self.root + for c in word: + cur = cur.child.get(c) + if cur is None: + return [self.eos] + return list(cur.child.keys()) + + +def apply_constraint( + input_ids: torch.Tensor, + logits: torch.Tensor, + decoder_prompts: Optional[list], + num_beams: int, + constraint_trie: Trie = None, +): + if decoder_prompts is None and constraint_trie is None: + return logits + + mask = logits.new_zeros(logits[:, -1, :].size(), dtype=torch.bool) + input_ids = input_ids.view(-1, num_beams, input_ids.shape[-1]) + for batch_id, beam_sent in enumerate(input_ids): + for beam_id, sent in enumerate(beam_sent): + if decoder_prompts is None: + prompt_len = 0 + else: + prompt_len = len(decoder_prompts[batch_id]) + + if sent.size(0) - 1 < prompt_len: + allowed_tokens = [decoder_prompts[batch_id][sent.size(0) - 1]] + mask[batch_id * num_beams + beam_id, allowed_tokens] = True + elif constraint_trie is not None: + answer_tokens = [0] + sent[prompt_len + 1:].tolist() + allowed_tokens = constraint_trie.get_next_layer(answer_tokens) + mask[batch_id * num_beams + beam_id, allowed_tokens] = True + else: + mask[batch_id * num_beams + beam_id, :] = True + logits[:, -1, :].masked_fill_(~mask, float('-inf')) + return logits + + +@MODELS.register_module() +class OFA(BaseModel): + """The OFA model for multiple tasks. + + Args: + encoder_cfg (dict): The config of the encoder, accept the keyword + arguments of :class:`OFAEncoder`. + decoder_cfg (dict): The config of the decoder, accept the keyword + arguments of :class:`OFADecoder`. + vocab_size (int): The size of the vocabulary. + embedding_dim (int): The embedding dimensions of both the encoder + and the decoder. + tokenizer (dict | PreTrainedTokenizer): The tokenizer to encode + the text. + task (str): The task name, supported tasks are "caption", "vqa" and + "refcoco". + prompt (str, optional): The prompt template for the following tasks, + If None, use default prompt: + + - **caption**: ' what does the image describe?' + - **refcoco**: ' which region does the text " {} " describe?' + + Defaults to None + ans2label (str | Sequence | None): The answer to label mapping for + the vqa task. If a string, it should be a pickle or json file. + The sequence constrains the output answers. Defaults to None, + which means no constraint. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of :class:`~transformers.GenerationConfig`. + Defaults to an empty dict. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. See :class: + `MultiModalDataPreprocessor` for more details. Defaults to None. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + support_tasks = {'caption', 'vqa', 'refcoco'} + + def __init__( + self, + encoder_cfg, + decoder_cfg, + vocab_size, + embedding_dim, + tokenizer, + task, + prompt=None, + ans2label: Union[dict, str, None] = None, + generation_cfg=dict(), + data_preprocessor: Optional[dict] = None, + init_cfg=None, + ): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if isinstance(tokenizer, dict): + self.tokenizer = TOKENIZER.build(tokenizer) + else: + self.tokenizer = tokenizer + + if task not in self.support_tasks: + raise ValueError(f'Unsupported task {task}, please select ' + f'the task from {self.support_tasks}.') + + self.prompt = prompt + self.task = task + + if isinstance(ans2label, str): + self.ans2label = mmengine.load(ans2label) + else: + self.ans2label = ans2label + + if self.task == 'vqa' and self.ans2label is not None: + self.constraint_trie = Trie(eos=self.tokenizer.eos_token_id) + answers = [f' {answer}' for answer in self.ans2label] + answer_tokens = self.tokenizer(answers, padding=False) + for answer_token in answer_tokens['input_ids']: + self.constraint_trie.insert(answer_token) + else: + self.constraint_trie = None + + generation_cfg = { + 'num_beams': 5, + 'max_new_tokens': 20, + 'no_repeat_ngram_size': 3, + **generation_cfg, + } + self.model = OFAEncoderDecoder( + encoder_cfg=encoder_cfg, + decoder_cfg=decoder_cfg, + padding_idx=self.tokenizer.pad_token_id, + vocab_size=vocab_size, + embedding_dim=embedding_dim, + generation_cfg=generation_cfg, + ) + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict( + self, + images, + data_samples=None, + post_process=True, + **generation_config, + ): + text_tokens = self.preprocess_text(data_samples, images.size(0), + images.device) + + if 'images_mask' in data_samples[0]: + images_mask = torch.tensor([ + sample.get('images_mask') for sample in data_samples + ]).bool().to(images.device) + else: + images_mask = None + + num_beams = generation_config.get( + 'num_beams', getattr(self.model.generation_config, 'num_beams')) + decoder_prompts = self.get_decoder_prompts(data_samples) + constrain_fn = partial( + apply_constraint, + constraint_trie=self.constraint_trie, + decoder_prompts=decoder_prompts, + num_beams=num_beams, + ) + + outputs = self.model.generate( + input_ids=text_tokens, + images=images, + images_mask=images_mask, + constrain_fn=constrain_fn, + **generation_config, + ) + + if decoder_prompts is not None: + # Remove the prefix decoder prompt. + for prompt_ids, token in zip(decoder_prompts, outputs): + token[1:len(prompt_ids) + 1] = self.tokenizer.pad_token_id + + if post_process: + return self.post_process(outputs, data_samples) + else: + return outputs + + def get_decoder_prompts(self, data_samples): + decoder_prompts = [] + if 'decoder_prompt' not in data_samples[0]: + return None + for sample in data_samples: + prompt = ' ' + sample.get('decoder_prompt') + prompt_ids = self.tokenizer(prompt, add_special_tokens=False) + prompt_ids = prompt_ids['input_ids'] + decoder_prompts.append(prompt_ids) + return decoder_prompts + + def preprocess_text(self, data_samples, batch_size, device): + if self.task == 'caption': + prompt = self.prompt or ' what does the image describe?' + prompts = [prompt] * batch_size + prompts = self.tokenizer(prompts, return_tensors='pt') + return prompts.input_ids.to(device) + elif self.task == 'vqa': + prompts = [] + for sample in data_samples: + assert 'question' in sample + prompt = ' ' + sample.get('question') + prompts.append(prompt) + prompts = self.tokenizer( + prompts, return_tensors='pt', padding=True) + return prompts.input_ids.to(device) + elif self.task == 'refcoco': + prompt_template = self.prompt or \ + ' which region does the text " {} " describe?' + prompts = [] + for sample in data_samples: + assert 'text' in sample + prompt = prompt_template.format(sample.get('text')) + prompts.append(prompt) + prompts = self.tokenizer( + prompts, return_tensors='pt', padding=True) + return prompts.input_ids.to(device) + + def post_process(self, outputs, data_samples): + + out_data_samples = [] + if data_samples is None: + data_samples = [None] * outputs.size(0) + + for data_sample, token in zip(data_samples, outputs): + if data_sample is None: + data_sample = DataSample() + + if self.task == 'caption': + text = self.tokenizer.decode(token, skip_special_tokens=True) + text = CleanCaption( + lowercase=False, + remove_chars=string.punctuation).clean(text) + data_sample.pred_caption = text + elif self.task == 'vqa': + text = self.tokenizer.decode(token, skip_special_tokens=True) + data_sample.pred_answer = text.strip() + elif self.task == 'refcoco': + bbox = token[1:5] - self.tokenizer.bin_offset + # During training, the bbox is normalized by 512. It's related + # to the `max_image_size` config in the official repo. + bbox = bbox / self.tokenizer.num_bins * 512 + scale_factor = data_sample.get('scale_factor', (1, 1)) + bbox[0::2] /= scale_factor[0] + bbox[1::2] /= scale_factor[1] + data_sample.pred_bboxes = bbox.unsqueeze(0) + if 'gt_bboxes' in data_sample: + gt_bboxes = bbox.new_tensor(data_sample.gt_bboxes) + gt_bboxes[:, 0::2] /= scale_factor[0] + gt_bboxes[:, 1::2] /= scale_factor[1] + data_sample.gt_bboxes = gt_bboxes + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/ofa/ofa_modules.py b/mmpretrain/models/multimodal/ofa/ofa_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1c79049b617685ad9d5ab244ed09c56e70b348fd --- /dev/null +++ b/mmpretrain/models/multimodal/ofa/ofa_modules.py @@ -0,0 +1,1612 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.utils import digit_version +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, ModelOutput, Seq2SeqLMOutput) +from transformers.modeling_utils import (GenerationConfig, GenerationMixin, + PretrainedConfig) + +from mmpretrain.registry import MODELS +from ...backbones.resnet import Bottleneck, ResNet + +if digit_version(torch.__version__) >= digit_version('1.10.0'): + torch_meshgrid = partial(torch.meshgrid, indexing='ij') +else: + torch_meshgrid = torch.meshgrid + + +def make_token_bucket_position(bucket_size, max_position=1024): + context_pos = torch.arange(max_position, dtype=torch.long)[:, None] + memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] + relative_pos = context_pos - memory_pos + sign = torch.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), + mid - 1, torch.abs(relative_pos)) + log_pos = torch.ceil( + torch.log(abs_pos / mid) / math.log( + (max_position - 1) / mid) * (mid - 1)) + mid + log_pos = log_pos.int() + bucket_pos = torch.where(abs_pos.le(mid), relative_pos, + log_pos * sign).long() + return bucket_pos + bucket_size - 1 + + +def make_image_bucket_position(bucket_size, num_relative_distance): + coords_h = torch.arange(bucket_size) + coords_w = torch.arange(bucket_size) + # (2, h, w) + coords = torch.stack(torch_meshgrid([coords_h, coords_w])) + # (2, h*w) + coords_flatten = torch.flatten(coords, 1) + # (2, h*w, h*w) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + # (h*w, h*w, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += bucket_size - 1 + relative_coords[:, :, 0] *= 2 * bucket_size - 1 + relative_position_index = torch.zeros( + size=(bucket_size * bucket_size + 1, ) * 2, + dtype=relative_coords.dtype) + # (h*w, h*w) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return relative_position_index + + +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + past_key_values_length: int = 0): + """Make causal mask used for uni-directional self-attention.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float('-inf')) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + """Expands attention_mask from ``[B, L_s]`` to ``[B, 1, L_t, L_s]``. + + Where ``B`` is batch_size, `L_s`` is the source sequence length, and + ``L_t`` is the target sequence length. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + return expanded_mask.masked_fill(expanded_mask.bool(), + torch.finfo(dtype).min) + + +class MultiheadAttention(BaseModule): + """Multi-head Attention Module for OFA. + + Args: + embedding_dim (int): The embedding dimension of query. + num_heads (int): Parallel attention heads. + kdim (int, optional): The embedding dimension of key. + Defaults to None, which means the same as the `embedding_dim`. + vdim (int, optional): The embedding dimension of value. + Defaults to None, which means the same as the `embedding_dim`. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + scale_factor (float): The scale of qk will be + ``(head_dim * scale_factor) ** -0.5``. Defaults to 1. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embedding_dim, + num_heads, + kdim=None, + vdim=None, + attn_drop=0., + scale_factor=1., + qkv_bias=True, + proj_bias=True, + scale_heads=False, + init_cfg=None): + super(MultiheadAttention, self).__init__(init_cfg=init_cfg) + + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.kdim = kdim or embedding_dim + self.vdim = vdim or embedding_dim + + self.head_dim = embedding_dim // num_heads + self.scale = (self.head_dim * scale_factor)**-0.5 + + self.q_proj = nn.Linear(embedding_dim, embedding_dim, bias=qkv_bias) + self.k_proj = nn.Linear(self.kdim, embedding_dim, bias=qkv_bias) + self.v_proj = nn.Linear(self.vdim, embedding_dim, bias=qkv_bias) + self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=proj_bias) + + self.attn_drop = nn.Dropout(p=attn_drop) + + if scale_heads: + self.c_attn = nn.Parameter(torch.ones(num_heads)) + else: + self.c_attn = None + + def forward( + self, + query, + key_value=None, + attn_mask=None, + attn_bias=None, + past_key_value=None, + output_attentions=False, + ): + B, _, C = query.shape + assert C == self.head_dim * self.num_heads + + is_cross_attention = key_value is not None + if key_value is None: + key_value = query + + # (B, L, C) -> (B, num_heads, L, head_dims) + q = self.q_proj(query).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + + if is_cross_attention and past_key_value is not None: + # Reuse key and value in cross_attentions + k, v = past_key_value + else: + k = self.k_proj(key_value).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + v = self.v_proj(key_value).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + if past_key_value is not None: + past_key, past_value = past_key_value + k = torch.cat([past_key, k], dim=2) + v = torch.cat([past_value, v], dim=2) + + past_key_value = (k, v) + + attn_weights = q @ k.transpose(-2, -1) * self.scale + + if attn_bias is not None: + src_len = k.size(2) + attn_weights[:, :, -src_len:] += attn_bias[:, :, -src_len:] + + if attn_mask is not None: + attn_weights += attn_mask + attn_weights = torch.softmax(attn_weights, dim=-1) + attn = self.attn_drop(attn_weights) @ v + + if self.c_attn is not None: + attn = torch.einsum('bhlc,h->bhlc', attn, self.c_attn) + + # (B, num_heads, L, head_dims) -> (B, L, C) + attn = attn.transpose(1, 2).reshape(B, -1, self.embedding_dim) + attn = self.out_proj(attn) + + if output_attentions: + return attn, attn_weights, past_key_value + else: + return attn, None, past_key_value + + +@MODELS.register_module(force=True) +class OFAResNet(ResNet): + """ResNet module for OFA. + + The ResNet in OFA has only three stages. + """ + arch_settings = { + 50: (Bottleneck, (3, 4, 6)), + 101: (Bottleneck, (3, 4, 23)), + 152: (Bottleneck, (3, 8, 36)), + } + + def __init__(self, depth, *args, **kwargs): + super().__init__( + depth=depth, + *args, + num_stages=3, + out_indices=(2, ), + dilations=(1, 1, 1), + strides=(1, 2, 2), + **kwargs) + + +@dataclass +class OFAEncoderOutput(ModelOutput): + """OFA encoder outputs. + + Args: + last_hidden_state (torch.tensor): The hidden-states of the output at + the last layer of the model. The shape is (B, L, C). + hidden_states (Tuple[torch.tensor]): The initial embedding and the + output of each layer. The shape of every item is (B, L, C). + attentions (Tuple[torch.tensor]): The attention weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. The shape of every item is + (B, num_heads, L, L). + position_embedding (torch.tensor): The positional embeddings of the + inputs. The shape is (B, L, C). + """ + + last_hidden_state: torch.FloatTensor = None + padding_mask: torch.Tensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + position_embedding: Optional[torch.FloatTensor] = None + + +class OFAEncoderLayer(nn.Module): + """OFAEncoder layer block.""" + + def __init__(self, + embedding_dim, + num_heads, + dropout_rate=0., + drop_path_rate=0., + attn_drop=0., + act_drop=0., + scale_factor=2., + mlp_ratio=4., + scale_heads=True, + normformer=True, + pre_norm=True, + act_cfg=dict(type='GELU')): + super().__init__() + self.embedding_dim = embedding_dim + self.pre_norm = pre_norm + + self.attn = MultiheadAttention( + embedding_dim=embedding_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + mid_channels = int(embedding_dim * mlp_ratio) + self.fc1 = nn.Linear(embedding_dim, mid_channels) + self.fc2 = nn.Linear(mid_channels, embedding_dim) + self.act = MODELS.build(act_cfg) + self.act_drop = nn.Dropout( + act_drop) if act_drop > 0. else nn.Identity() + + # LayerNorm between attention block and ffn block. + self.attn_ln = nn.LayerNorm(embedding_dim) + self.ffn_ln = nn.LayerNorm(embedding_dim) + + # Extra LayerNorm + self.normformer = normformer + if self.normformer: + self.attn_mid_ln = nn.LayerNorm(embedding_dim) + self.ffn_mid_ln = nn.LayerNorm(mid_channels) + + self.dropout = nn.Dropout(dropout_rate) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, + x, + attention_mask=None, + attn_bias=None, + output_attentions=False): + """Forward the encoder layer. + + Args: + x (torch.tensor): The input to the layer of shape ``(B, L, C)``. + attention_mask (torch.Tensor, optional): The attention mask of size + ``(B, 1, L, L)``, where padding elements are indicated by very + large negative values. Defaults to None. + attn_bias (torch.tensor, optional): The bias for positional + information. Defaults to None. + output_attentions (bool): Whether to return the attentions tensors + of the attention layer. + + Returns: + List[torch.tensor]: The first element is the encoded output of + shape ``(B, L, C)``. And the second element is the output + attentions if ``output_attentions=True``. + """ + residual = x + + # Attention block + if self.pre_norm: + x = self.attn_ln(x) + x, attn_weights, _ = self.attn( + query=x, + attn_mask=attention_mask, + attn_bias=attn_bias, + output_attentions=output_attentions) + if self.normformer: + x = self.attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.attn_ln(x) + + residual = x + + # FFN block + if self.pre_norm: + x = self.ffn_ln(x) + x = self.act(self.fc1(x)) + x = self.act_drop(x) + if self.normformer: + x = self.ffn_mid_ln(x) + x = self.fc2(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.ffn_ln(x) + + if output_attentions: + return [x, attn_weights] + else: + return [x] + + +class OFADecoderLayer(nn.Module): + """OFADecoder layer block.""" + + def __init__(self, + embedding_dim, + num_heads, + dropout_rate=0., + drop_path_rate=0., + attn_drop=0., + act_drop=0., + scale_factor=2., + mlp_ratio=4., + encoder_embed_dim=None, + scale_heads=True, + normformer=True, + pre_norm=True, + act_cfg=dict(type='GELU')): + super().__init__() + self.embedding_dim = embedding_dim + self.pre_norm = pre_norm + + self.self_attn = MultiheadAttention( + embedding_dim=embedding_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + self.cross_attn = MultiheadAttention( + embedding_dim=embedding_dim, + kdim=encoder_embed_dim, + vdim=encoder_embed_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + mid_channels = int(embedding_dim * mlp_ratio) + self.fc1 = nn.Linear(embedding_dim, mid_channels) + self.fc2 = nn.Linear(mid_channels, embedding_dim) + self.act = MODELS.build(act_cfg) + self.act_drop = nn.Dropout( + act_drop) if act_drop > 0. else nn.Identity() + + # LayerNorm between attention block and ffn block. + self.self_attn_ln = nn.LayerNorm(embedding_dim) + self.cross_attn_ln = nn.LayerNorm(embedding_dim) + self.ffn_ln = nn.LayerNorm(embedding_dim) + + # Extra LayerNorm + self.normformer = normformer + if self.normformer: + self.self_attn_mid_ln = nn.LayerNorm(embedding_dim) + self.cross_attn_mid_ln = nn.LayerNorm(embedding_dim) + self.ffn_mid_ln = nn.LayerNorm(mid_channels) + + self.dropout = nn.Dropout(dropout_rate) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward( + self, + x, + attention_mask=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[List[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + self_attn_bias: Optional[torch.Tensor] = None, + cross_attn_bias: Optional[torch.Tensor] = None, + ): + """Forward the decoder layer. + + Args: + x (torch.tensor): The input to the layer of shape ``(B, L, C)``. + attention_mask (torch.Tensor, optional): The attention mask of size + ``(B, 1, L, L)``, where padding elements are indicated by very + large negative values. Defaults to None. + encoder_hidden_states (torch.Tensor, optional): The cross attention + input to the layer of size ``(B, L, C)``. Defaults to None. + encoder_attention_mask (torch.Tensor, optional): The cross + attention mask where padding elements are indicated by very + large negative values. Defaults to None. + past_key_value (Tuple[torch.tensor], optional): The cached past key + and value projection states. Defaults to none. + output_attentions (bool): whether to return the attentions tensors + of all attention layers. Defaults to False. + use_cache (bool, optional): Whether to use cache. + Defaults to False. + self_attn_bias (torch.Tensor, optional): The self attention bias + for positional information. Defaults to None. + cross_attn_bias (torch.Tensor, optional): The cross attention bias + for positional information. Defaults to None. + + Returns: + List[torch.tensor]: The first element is the encoded output of + shape ``(B, L, C)``. The following two elements can be the output + self-attentions and cross-attentions if ``output_attentions=True``. + The following one element can be the cached past key and value + projection states. + """ + residual = x + + if past_key_value is not None: + self_past_key_value = past_key_value[:2] + cross_past_key_value = past_key_value[2:] + else: + self_past_key_value, cross_past_key_value = None, None + + # Self-Attention block + if self.pre_norm: + x = self.self_attn_ln(x) + x, self_attn_weights, present_key_value = self.self_attn( + query=x, + past_key_value=self_past_key_value, + attn_mask=attention_mask, + output_attentions=output_attentions, + attn_bias=self_attn_bias, + ) + if self.normformer: + x = self.self_attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.self_attn_ln(x) + + # Cross-Attention block + if encoder_hidden_states is not None: + residual = x + if self.pre_norm: + x = self.cross_attn_ln(x) + x, cross_attn_weights, cross_key_value = self.cross_attn.forward( + query=x, + key_value=encoder_hidden_states, + attn_mask=encoder_attention_mask, + past_key_value=cross_past_key_value, + output_attentions=output_attentions, + attn_bias=cross_attn_bias) + if self.normformer: + x = self.cross_attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.cross_attn_ln(x) + + present_key_value = present_key_value + cross_key_value + + residual = x + + # FFN block + if self.pre_norm: + x = self.ffn_ln(x) + x = self.act(self.fc1(x)) + x = self.act_drop(x) + if self.normformer: + x = self.ffn_mid_ln(x) + x = self.fc2(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.ffn_ln(x) + + outputs = [x] + + if output_attentions: + outputs.extend([self_attn_weights, cross_attn_weights]) + + if use_cache: + outputs.append(present_key_value) + + return outputs + + +class OFAEncoder(BaseModule): + """The encoder module of OFA. + + Args: + embed_tokens (nn.Embedding): The embedding module to embed the + input tokens. + embed_images (dict | nn.Module): The module to embed the input + images into features. The output number of channels should + be 1024. + num_layers (int): The number of encoder layers. Defaults to 6. + num_heads (int): The number of heads of attention. Defaults to 12. + dropout_rate (float): The prob of dropout for embedding and + transformer layers. Defaults to 0. + drop_path_rate (float): The prob of droppath for transformer layers. + Defaults to 0. + max_source_positions (int): The maximum length of the input tokens. + Defaults to 1024. + token_bucket_size (int): The token bucket size, it's used as the + maximum relative position index in relative position embedding + of input tokens. Defaults to 256. + image_bucket_size (int): The image bucket size, it's used to generate + the image relative position embedding table. It should be larger + than the shape of image feature map. Defaults to 42. + attn_scale_factor (float): The scale factor to calculate qk scale in + attentions. Defaults to 2. + scale_embedding (bool): Whether to scale the embeddings by the square + root of the dimension. Defaults to False. + add_embedding_ln (bool): Whether to add an extra layer norm for token + embeddings. Defaults to True. + add_image_embedding_ln (bool): Whether to add an extra layer norm for + image embeddings. Defaults to True. + pre_norm (bool): Whether to do layer norm before attention and ffn + blocks in transformer layers. Defaults to True. + entangle_position_embedding (bool): Whether to add the position + embedding on the embeddings directly. Defaults to False. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + + def __init__( + self, + embed_tokens, + embed_images: dict, + num_layers=6, + num_heads=12, + dropout_rate=0., + drop_path_rate=0., + max_source_positions=1024, + token_bucket_size=256, + image_bucket_size=42, + attn_scale_factor=2., + scale_embedding=False, + add_embedding_ln=True, + add_type_embed=True, + add_image_embedding_ln=True, + pre_norm=True, + entangle_position_embedding=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + + self.num_layers = num_layers + embedding_dim = embed_tokens.embedding_dim + self.embedding_dim = embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = max_source_positions + self.num_heads = num_heads + + # Build embedding process components + self.embed_tokens = embed_tokens + self.embedding_scale = math.sqrt( + embedding_dim) if scale_embedding else 1.0 + + if not isinstance(embed_images, nn.Module): + self.embed_images = MODELS.build(embed_images) + else: + self.embed_images = embed_images + self.image_proj = nn.Linear(1024, embedding_dim) + + if add_embedding_ln: + self.embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.embedding_ln = None + + if add_type_embed: + self.embed_type = nn.Embedding(2, embedding_dim) + else: + self.embed_type = None + + if add_image_embedding_ln: + self.image_embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.image_embedding_ln = None + + self.entangle_position_embedding = entangle_position_embedding + + # Build position embedding + self.embed_positions = nn.Embedding(self.max_source_positions + 2, + embedding_dim) + self.pos_ln = nn.LayerNorm(embedding_dim) + self.embed_image_positions = nn.Embedding(image_bucket_size**2 + 1, + embedding_dim) + self.image_pos_ln = nn.LayerNorm(embedding_dim) + + self.pos_scaling = float(embedding_dim / num_heads * + attn_scale_factor)**-0.5 + self.pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + + self.dropout = nn.Dropout( + dropout_rate) if dropout_rate > 0. else nn.Identity() + + # Register token relative position embedding table + self.token_bucket_size = token_bucket_size + token_num_rel_dis = 2 * token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(token_bucket_size, + self.max_source_positions) + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.token_rel_pos_table_list = nn.ModuleList() + + # Register image relative position embedding table + self.image_bucket_size = image_bucket_size + image_num_rel_dis = (2 * image_bucket_size - + 1) * (2 * image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position(image_bucket_size, + image_num_rel_dis) + self.register_buffer('image_rp_bucket', image_rp_bucket) + self.image_rel_pos_table_list = nn.ModuleList() + + # Build encoder layers + self.layers = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + for index in range(self.num_layers): + layer = OFAEncoderLayer( + embedding_dim=embedding_dim, + num_heads=num_heads, + dropout_rate=dropout_rate, + drop_path_rate=dpr[index], + scale_factor=attn_scale_factor, + pre_norm=pre_norm, + ) + self.layers.append(layer) + token_pos_table = nn.Embedding(token_num_rel_dis, self.num_heads) + image_pos_table = nn.Embedding(image_num_rel_dis, self.num_heads) + nn.init.constant_(token_pos_table.weight, 0.) + nn.init.constant_(image_pos_table.weight, 0.) + self.token_rel_pos_table_list.append(token_pos_table) + self.image_rel_pos_table_list.append(image_pos_table) + + if pre_norm: + self.final_ln = nn.LayerNorm(embedding_dim) + else: + self.final_ln = None + + main_input_name = 'input_ids' + + def forward(self, + input_ids, + images, + images_mask, + output_attentions=False, + output_hidden_states=False, + sample_patch_num=None): + padding_mask = input_ids.eq(self.padding_idx) + has_pads = padding_mask.any() + token_embedding = self.embed_tokens(input_ids) + token_embedding = self.embedding_scale * token_embedding + + # Embed the token position + src_pos_idx = torch.arange(input_ids.size(-1), device=input_ids.device) + src_pos_idx = src_pos_idx.expand(*input_ids.shape).contiguous() + pos_embedding = self.embed_positions(src_pos_idx) + + # Embed the input tokens + x = self.process_embedding( + embedding=token_embedding, + type_tokens=input_ids.new_zeros(token_embedding.shape[:2]), + pos_embedding=pos_embedding, + embedding_ln=self.embedding_ln, + ) + pos_embedding = self.pos_ln(pos_embedding) + + # Embed the input images + if images is not None: + (image_tokens, image_padding_mask, image_position_ids, + image_pos_embedding) = self.get_image_tokens( + images, + sample_patch_num, + images_mask, + ) + image_embedding = self.image_proj(image_tokens) + + image_x = self.process_embedding( + embedding=image_embedding, + type_tokens=input_ids.new_ones(image_embedding.shape[:2]), + pos_embedding=image_pos_embedding, + embedding_ln=self.image_embedding_ln, + ) + image_pos_embedding = self.image_pos_ln(image_pos_embedding) + + x = torch.cat([image_x, x], dim=1) + padding_mask = torch.cat([image_padding_mask, padding_mask], dim=1) + pos_embedding = torch.cat([image_pos_embedding, pos_embedding], + dim=1) + + # account for padding while computing the representation + if has_pads: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + # Decoupled position embedding + B, L = pos_embedding.shape[:2] + pos_q = self.pos_q_linear(pos_embedding).view( + B, L, self.num_heads, -1).transpose(1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(pos_embedding).view(B, L, self.num_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + all_hidden_states = [] if output_hidden_states else None + all_attentions = [] if output_attentions else None + + for idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(x) + + self_attn_bias = abs_pos_bias.clone() + # Add decoupled position embedding for input tokens. + token_len = input_ids.size(1) + rel_pos_bias = self.get_rel_pos_bias(input_ids, idx) + self_attn_bias[:, :, -token_len:, -token_len:] += rel_pos_bias + + # Add decoupled position embedding for images + if images is not None: + token_len = image_tokens.size(1) + rel_pos_bias = self.get_image_rel_pos_bias( + image_position_ids, idx) + self_attn_bias[:, :, :token_len, :token_len] += rel_pos_bias + + if has_pads: + attention_mask = _expand_mask(padding_mask, dtype=x.dtype) + else: + attention_mask = None + + out = layer( + x, + attention_mask=attention_mask, + attn_bias=self_attn_bias, + output_attentions=output_attentions) + x = out[0] + + if output_attentions: + all_attentions.append(out[1]) + + if output_hidden_states: + all_hidden_states.append(x) + + if self.final_ln is not None: + x = self.final_ln(x) + + return OFAEncoderOutput( + last_hidden_state=x, # (B, L, C) + padding_mask=padding_mask, # (B, L) + position_embedding=pos_embedding, # (B, L, C) + hidden_states=all_hidden_states, # list of (B, L, C) + attentions=all_attentions, # list of (B, num_heads, L, head_dims) + ) + + def get_image_tokens(self, images, sample_patch_num, images_mask): + image_embedding = self.embed_images(images)[-1] + B, C, H, W = image_embedding.shape + num_patches = H * W + + padding_mask = images.new_zeros((B, num_patches)).bool() + position_col = torch.arange(W).unsqueeze(0) + position_row = torch.arange(H).unsqueeze(1) * self.image_bucket_size + position_idx = (position_col + position_row + 1).view(-1) + position_idx = position_idx.to(images.device).expand(B, num_patches) + + # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C) + image_embedding = image_embedding.flatten(2).transpose(1, 2) + if sample_patch_num is not None: + patch_orders = torch.stack([ + torch.randperm(num_patches)[:sample_patch_num] + for _ in range(B) + ]) + num_patches = sample_patch_num + image_embedding = image_embedding.gather( + dim=1, index=patch_orders.unsqueeze(2).expand(-1, -1, C)) + padding_mask = padding_mask.gather(1, patch_orders) + position_idx = position_idx.gather(1, patch_orders) + + pos_embedding = self.embed_image_positions(position_idx) + padding_mask[~images_mask] = True + return image_embedding, padding_mask, position_idx, pos_embedding + + def process_embedding(self, + embedding, + pos_embedding=None, + type_tokens=None, + embedding_ln=None): + if self.entangle_position_embedding and pos_embedding is not None: + embedding += pos_embedding + if self.embed_type is not None: + embedding += self.embed_type(type_tokens) + if embedding_ln is not None: + embedding = embedding_ln(embedding) + embedding = self.dropout(embedding) + + return embedding + + def get_rel_pos_bias(self, x, idx): + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) + values = values.permute([0, 3, 1, 2]) + return values.contiguous() + + def get_image_rel_pos_bias(self, image_position_ids, idx): + bsz, seq_len = image_position_ids.shape + rp_bucket_size = self.image_rp_bucket.size(1) + + rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( + bsz, rp_bucket_size, rp_bucket_size).gather( + 1, image_position_ids[:, :, None].expand( + bsz, seq_len, rp_bucket_size)).gather( + 2, image_position_ids[:, None, :].expand( + bsz, seq_len, seq_len)) + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(0, 3, 1, 2) + return values + + +class OFADecoder(BaseModule): + """The decoder module of OFA. + + Args: + embed_tokens (nn.Embedding): The embedding module to embed the + input tokens. + num_layers (int): The number of decoder layers. Defaults to 6. + num_heads (int): The number of heads of attention. Defaults to 12. + dropout_rate (float): The prob of dropout for embedding and + transformer layers. Defaults to 0. + drop_path_rate (float): The prob of droppath for transformer layers. + Defaults to 0. + max_target_positions (int): The maximum length of the input tokens. + Defaults to 1024. + code_image_size (int): The resolution of the generated image in the + image infilling task. Defaults to 128. + token_bucket_size (int): The token bucket size, it's used as the + maximum relative position index in relative position embedding + of input tokens. Defaults to 256. + image_bucket_size (int): The image bucket size, it's used to generate + the image relative position embedding table. It should be larger + than the shape of image feature map. Defaults to 42. + attn_scale_factor (float): The scale factor to calculate qk scale in + attentions. Defaults to 2. + scale_embedding (bool): Whether to scale the embeddings by the square + root of the dimension. Defaults to False. + add_embedding_ln (bool): Whether to add an extra layer norm for token + embeddings. Defaults to True. + add_code_embedding_ln (bool): Whether to add an extra layer norm for + code embeddings. Defaults to True. + pre_norm (bool): Whether to do layer norm before attention and ffn + blocks in transformer layers. Defaults to True. + entangle_position_embedding (bool): Whether to add the position + embedding on the embeddings directly. Defaults to False. + share_input_output_embed (bool): Share the weights of the input token + embedding module and the output projection module. + Defaults to True. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + + def __init__( + self, + embed_tokens, + num_layers=6, + num_heads=12, + dropout_rate=0., + drop_layer_rate=0., + drop_path_rate=0., + max_target_positions=1024, + code_image_size=128, + token_bucket_size=256, + image_bucket_size=42, + attn_scale_factor=2., + scale_embedding=False, + add_embedding_ln=True, + add_code_embedding_ln=True, + pre_norm=True, + entangle_position_embedding=False, + share_input_output_embed=True, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self._future_mask = torch.empty(0) + + self.num_layers = num_layers + embedding_dim = embed_tokens.embedding_dim + self.embedding_dim = embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = max_target_positions + self.num_heads = num_heads + + # Build embedding process components + self.embed_tokens = embed_tokens + self.embedding_scale = math.sqrt( + embedding_dim) if scale_embedding else 1.0 + + if add_embedding_ln: + self.embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.embedding_ln = None + + if add_code_embedding_ln: + self.code_embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.code_embedding_ln = None + + # Build position embedding + self.embed_positions = nn.Embedding(self.max_target_positions + 2, + embedding_dim) + self.pos_ln = nn.LayerNorm(embedding_dim) + self.embed_image_positions = nn.Embedding(image_bucket_size**2 + 1, + embedding_dim) + self.image_pos_ln = nn.LayerNorm(embedding_dim) + + self.pos_scaling = float(embedding_dim / num_heads * + attn_scale_factor)**-0.5 + self.self_pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.self_pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + self.cross_pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.cross_pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + + self.entangle_position_embedding = entangle_position_embedding + + self.dropout = nn.Dropout( + dropout_rate) if dropout_rate > 0. else nn.Identity() + if drop_layer_rate > 0.: + raise NotImplementedError + + # Register token relative position embedding table + self.token_bucket_size = token_bucket_size + token_num_rel_dis = 2 * token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(token_bucket_size) + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.token_rel_pos_table_list = nn.ModuleList() + + # Register image relative position embedding table + self.image_bucket_size = image_bucket_size + image_num_rel_dis = (2 * image_bucket_size - + 1) * (2 * image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position(image_bucket_size, + image_num_rel_dis) + self.register_buffer('image_rp_bucket', image_rp_bucket) + self.image_rel_pos_table_list = nn.ModuleList() + + self.window_size = code_image_size // 8 + + position_col = torch.arange(self.window_size).unsqueeze(0) + position_row = torch.arange( + self.window_size).unsqueeze(1) * self.image_bucket_size + image_position_idx = (position_col + position_row + 1) + image_position_idx = torch.cat( + [torch.tensor([0]), image_position_idx.view(-1)]) + image_position_idx = torch.cat( + [image_position_idx, + torch.tensor([1024] * 768)]) + self.register_buffer('image_position_idx', image_position_idx) + + # Build decoder layers + self.layers = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + for index in range(self.num_layers): + layer = OFADecoderLayer( + embedding_dim=embedding_dim, + num_heads=num_heads, + dropout_rate=dropout_rate, + drop_path_rate=dpr[index], + scale_factor=attn_scale_factor, + pre_norm=pre_norm, + ) + self.layers.append(layer) + token_pos_table = nn.Embedding(token_num_rel_dis, self.num_heads) + image_pos_table = nn.Embedding(image_num_rel_dis, self.num_heads) + nn.init.constant_(token_pos_table.weight, 0.) + nn.init.constant_(image_pos_table.weight, 0.) + self.token_rel_pos_table_list.append(token_pos_table) + self.image_rel_pos_table_list.append(image_pos_table) + + if pre_norm: + self.final_ln = nn.LayerNorm(embedding_dim) + else: + self.final_ln = None + + # Build output projection + if share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + vocab_size = self.embed_tokens.num_embeddings + self.output_projection = nn.Linear( + embedding_dim, vocab_size, bias=False) + nn.init.normal_( + self.output_projection.weight, + mean=0, + std=embedding_dim**-0.5, + ) + + main_input_name = 'input_ids' + + def forward( + self, + input_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + code_masks: Optional[torch.Tensor] = None, + encoder_pos_embedding: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + + if past_key_values is not None and len(past_key_values) > 0: + B, _, L_past, _ = past_key_values[0][0].shape + L = L_past + 1 + else: + B, L = input_ids.shape + L_past = 0 + + # Embed the token position + target_pos_idx = torch.arange( + L, device=input_ids.device).expand([B, L]).contiguous() + pos_embedding = self.embed_positions(target_pos_idx) + + # Embed the code positions + if code_masks is not None and torch.any(code_masks): + image_position_idx = self.image_position_idx[:input_ids.size(1)] + image_position_idx = image_position_idx.unsqueeze(0).expand(B, L) + pos_embedding[code_masks] = self.embed_image_positions( + image_position_idx)[code_masks] + + # Self-attention position bias (B, num_heads, L_t, L_t) + self_abs_pos_bias = self.get_pos_info(self.pos_ln(pos_embedding)) + if code_masks is not None and torch.any(code_masks): + self_image_abs_pos_bias = self.get_pos_info( + self.image_pos_ln(pos_embedding)) + self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] + + # Cross-attention position bias (B, num_heads, L_t, L_s) + cross_abs_pos_bias = self.get_pos_info( + self.pos_ln(pos_embedding), encoder_pos_embedding) + if code_masks is not None and torch.any(code_masks): + cross_image_abs_pos_bias = self.get_pos_info( + self.image_pos_ln(pos_embedding), encoder_pos_embedding) + cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[ + code_masks] + + all_prev_output_tokens = input_ids.clone() + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + cross_abs_pos_bias = cross_abs_pos_bias[:, :, -1:, :] + pos_embedding = pos_embedding[:, -1:, :] + + # Embed the input tokens + x = self.embed_tokens(input_ids) * self.embedding_scale + + if self.entangle_position_embedding: + x += pos_embedding + + if self.embedding_ln is not None: + if (code_masks is None or not code_masks.any() + or self.code_embedding_ln is None): + x = self.embedding_ln(x) + elif code_masks is not None and code_masks.all(): + x = self.code_embedding_ln(x) + else: + x[~code_masks] = self.embedding_ln(x[~code_masks]) + x[code_masks] = self.code_embedding_ln(x[code_masks]) + + x = self.dropout(x) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_ids.shape, x.dtype, L_past) + attention_mask = attention_mask.to(x.device) + + # decoder layers + all_hidden_states = [] if output_hidden_states else None + all_self_attns = [] if output_attentions else None + all_cross_attentions = [] if ( + output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = [] if use_cache else None + + for idx, layer in enumerate(self.layers): + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states.append(x) + + if past_key_values is not None and len(past_key_values) > 0: + past_key_value = past_key_values[idx] + else: + past_key_value = None + + self_attn_bias = self_abs_pos_bias.clone() + if code_masks is None or not code_masks.any(): + self_attn_bias += self.get_rel_pos_bias( + all_prev_output_tokens, idx) + elif code_masks is not None and code_masks.all(): + self_attn_bias += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx) + else: + self_attn_bias[~code_masks] += self.get_rel_pos_bias( + all_prev_output_tokens, idx) + self_attn_bias[code_masks] += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx) + + if past_key_value is not None: + self_attn_bias = self_attn_bias[:, :, -1:, :] + + out = layer( + x, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + self_attn_bias=self_attn_bias, + cross_attn_bias=cross_abs_pos_bias, + ) + x = out.pop(0) + + if output_attentions: + all_self_attns.append(out.pop(0)) + if encoder_hidden_states is not None: + all_cross_attentions.append(out.pop(0)) + + if use_cache: + next_decoder_cache.append(out.pop(0)) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (x, ) + + if self.final_ln is not None: + x = self.final_ln(x) + + x = self.output_projection(x) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=x, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + def _prepare_decoder_attention_mask( + self, + attention_mask, + input_shape, + dtype, + past_key_values_length, + ): + r""" + Create causal mask for unidirectional decoding. + [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + """ + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + dtype, + past_key_values_length=past_key_values_length).to( + attention_mask.device) + + if attention_mask is not None: + # (B, L_s) -> (B, 1, L_t, L_s) + expanded_attention_mask = _expand_mask( + attention_mask, dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attention_mask if combined_attention_mask is None else + expanded_attention_mask + combined_attention_mask) + + return combined_attention_mask + + def get_pos_info(self, pos_embedding, src_pos_embedding=None): + B, tgt_len = pos_embedding.shape[:2] + if src_pos_embedding is not None: + src_len = src_pos_embedding.size(1) + pos_q = self.cross_pos_q_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + pos_q = pos_q * self.pos_scaling + pos_k = self.cross_pos_k_linear(src_pos_embedding).view( + B, src_len, self.num_heads, -1).transpose(1, 2) + else: + pos_q = self.self_pos_q_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + pos_q = pos_q * self.pos_scaling + pos_k = self.self_pos_k_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + return abs_pos_bias + + def get_rel_pos_bias(self, x, idx): + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) + values = values.permute([0, 3, 1, 2]) + return values.contiguous() + + def get_image_rel_pos_bias(self, image_position_ids, idx): + bsz, seq_len = image_position_ids.shape + rp_bucket_size = self.image_rp_bucket.size(1) + + rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( + bsz, rp_bucket_size, rp_bucket_size).gather( + 1, image_position_ids[:, :, None].expand( + bsz, seq_len, rp_bucket_size)).gather( + 2, image_position_ids[:, None, :].expand( + bsz, seq_len, seq_len)) + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(0, 3, 1, 2) + return values + + +class OFAEncoderDecoder(BaseModule, GenerationMixin): + """The OFA main architecture with an encoder and a decoder. + + Args: + encoder_cfg (dict): The config of the encoder, accept the keyword + arguments of :class:`OFAEncoder`. + decoder_cfg (dict): The config of the decoder, accept the keyword + arguments of :class:`OFADecoder`. + padding_idx (int): The index of the padding token. + vocab_size (int): The size of the vocabulary. + embedding_dim (int): The embedding dimensions of both the encoder + and the decoder. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of :class:`~transformers.GenerationConfig`. + Defaults to an empty dict. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + + def __init__( + self, + encoder_cfg, + decoder_cfg, + padding_idx, + vocab_size, + embedding_dim, + generation_cfg=dict(), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + + self.padding_idx = padding_idx + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + embed_tokens = nn.Embedding(vocab_size, embedding_dim, padding_idx) + + self.encoder = OFAEncoder(embed_tokens, **encoder_cfg) + self.decoder = OFADecoder(embed_tokens, **decoder_cfg) + + self.config = PretrainedConfig( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + bos_token_id=0, + decoder_start_token_id=0, + pad_token_id=1, + eos_token_id=2, + forced_eos_token_id=2, + use_cache=False, + is_encoder_decoder=True, + ) + self.config.update(generation_cfg) + + self.generation_config = GenerationConfig.from_model_config( + self.config) + + @property + def device(self): + return next(self.parameters()).device + + def can_generate(self): + return True + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + def get_normalized_probs(self, net_output, log_probs: bool, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + + def get_normalized_probs_scriptable( + self, + net_output, + log_probs: bool, + sample=None, + ): + """Scriptable helper function for get_normalized_probs in. + + ~BaseFairseqModel. + """ + if hasattr(self, 'decoder'): + return self.decoder.get_normalized_probs(net_output, log_probs, + sample) + elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) + logits = net_output.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + main_input_name = 'input_ids' + + def forward(self, + input_ids=None, + images=None, + images_mask=None, + sample_patch_num=None, + decoder_input_ids=None, + code_masks=None, + attention_mask=None, + encoder_outputs=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + constrain_fn=None, + return_dict=False): + """Forword the module. + + Args: + input_ids (torch.Tensor): The indices of the input tokens in the + vocabulary, and padding will be ignored by default. The indices + can be obtained using :class:`OFATokenizer`. + The shape is (B, L). + images (torch.Tensor): The input images. The shape is (B, 3, H, W). + images_mask (torch.Tensor): The mask of all available images. The + shape is (B, ). + sample_patch_num (int): The number of patches to sample for the + images. Defaults to None, which means to use all patches. + decoder_input_ids (torch.Tensor): The indices of the input tokens + for the decoder. + code_masks (torch.Tensor): The mask of all samples for image + generation. The shape is (B, ). + attention_mask (torch.Tensor): The attention mask for decoding. + The shape is (B, L). + encoder_outputs (OFAEncoderOutput): The encoder outputs with hidden + states, positional embeddings, and padding masks. + past_key_values (Tuple[Tuple[torch.Tensor]]): If use cache, the + parameter is a tuple of length ``num_layers``. Every item is + also a tuple with four tensors, two for the key and value of + self-attention, two for the key and value of cross-attention. + use_cache (bool): Whether to use cache for faster inference. + Defaults to False. + output_attentions (bool): Whether to output attention weights. + Defaults to False. + output_hidden_states (bool): Whether to output hidden states. + Defaults to False. + constrain_fn (Callable, optional): The function to constrain the + output logits. Defaults to None. + return_dict (bool): Not used, it's only for compat with the + interface of the ``generate`` of ``transformers``. + + Returns: + Seq2SeqLMOutput: + + - logits (``torch.Tensor``): The last decoder hidden states. + The shape is (B, L, C). + - past_key_values (``Tuple[Tuple[torch.Tensor]]``): The past keys + and values for faster inference. + - decoder_hidden_states (``Tuple[torch.Tensor]``): the decoder + hidden states of all layers. + - decoder_attentions (``Tuple[torch.Tensor]``): The self-attention + weights of all layers in the decoder. + - cross_attentions (``Tuple[torch.Tensor]``): The cross-attention + weights of all layers in the decoder. + - encoder_last_hidden_state (``torch.Tensor``): The last encoder + hidden states. + - encoder_hidden_states (``Tuple[torch.Tensor]``): The encoder + hidden states of all layers, including the embeddings. + - encoder_attentions (``Tuple[torch.Tensor]``): The self-attention + weights of all layers in the encoder. + """ + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + images=images, + images_mask=images_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + sample_patch_num=sample_patch_num, + ) + + if decoder_input_ids.eq(self.padding_idx).any(): + attention_mask = decoder_input_ids.eq(self.padding_idx) + + encoder_hidden_states = encoder_outputs.last_hidden_state + encoder_attention_mask = _expand_mask(encoder_outputs.padding_mask, + encoder_hidden_states.dtype, + decoder_input_ids.shape[-1]) + src_pos_embed = encoder_outputs.position_embedding + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_masks, + encoder_pos_embedding=src_pos_embed, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + # The constrain operation for fine-tuned model in OFA is applied + # before log_softmax, therefore we cannot use + # `prefix_allowed_tokens_fn` to implement it. + if constrain_fn is not None: + logits = constrain_fn(decoder_input_ids, + decoder_outputs.last_hidden_state) + else: + logits = decoder_outputs.last_hidden_state + + return Seq2SeqLMOutput( + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + decoder_input_ids=None, + past=None, + attention_mask=None, + code_masks=None, + use_cache=False, + encoder_outputs=None, + constrain_fn=None, + **kwargs): + # if attention_mask is None: + attention_mask = decoder_input_ids.new_zeros(decoder_input_ids.shape) + + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + 'input_ids': None, + 'images': None, + 'images_mask': None, + 'sample_patch_num': None, + 'attention_mask': attention_mask, + 'encoder_outputs': encoder_outputs, + 'past_key_values': past, + 'decoder_input_ids': decoder_input_ids, + 'code_masks': code_masks, + 'use_cache': use_cache, + 'constrain_fn': constrain_fn, + } + + def _prepare_encoder_decoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str] = None): + # 1. get encoder + encoder = self.get_encoder() + + # 2. prepare encoder args and encoder kwargs from model kwargs + irrelevant_prefix = [ + 'decoder_', 'cross_attn', 'use_cache', 'attention_mask', + 'constrain_fn' + ] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + + if encoder_kwargs.get('images_mask') is None: + encoder_kwargs['images_mask'] = torch.tensor([True] * + inputs_tensor.size(0)) + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name or self.main_input_name + encoder_kwargs[model_input_name] = inputs_tensor + model_kwargs['encoder_outputs']: ModelOutput = encoder( + **encoder_kwargs) + model_kwargs['attention_mask'] = None + + return model_kwargs + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + **model_kwargs, + ): + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat( + 1, expand_size).view(-1).to(input_ids.device)) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs['attention_mask'] = attention_mask.index_select( + 0, expanded_return_idx) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError('If `is_encoder_decoder` is True, make ' + 'sure that `encoder_outputs` is defined.') + encoder_outputs['last_hidden_state'] = encoder_outputs.\ + last_hidden_state.index_select(0, expanded_return_idx) + encoder_outputs['position_embedding'] = encoder_outputs.\ + position_embedding.index_select(0, expanded_return_idx) + encoder_outputs['padding_mask'] = encoder_outputs.\ + padding_mask.index_select(0, expanded_return_idx) + model_kwargs['encoder_outputs'] = encoder_outputs + return input_ids, model_kwargs diff --git a/mmpretrain/models/multimodal/otter/__init__.py b/mmpretrain/models/multimodal/otter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38a45a3d17458eae2471846b43498aa06cdfaac3 --- /dev/null +++ b/mmpretrain/models/multimodal/otter/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .otter import Otter + +__all__ = ['Otter'] diff --git a/mmpretrain/models/multimodal/otter/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/multimodal/otter/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..598c47bece0f49b4762a066fcbd2ea1a53483357 Binary files /dev/null and b/mmpretrain/models/multimodal/otter/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/otter/__pycache__/otter.cpython-38.pyc b/mmpretrain/models/multimodal/otter/__pycache__/otter.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc9add819126e2a68df29e11665150782e7d0512 Binary files /dev/null and b/mmpretrain/models/multimodal/otter/__pycache__/otter.cpython-38.pyc differ diff --git a/mmpretrain/models/multimodal/otter/otter.py b/mmpretrain/models/multimodal/otter/otter.py new file mode 100644 index 0000000000000000000000000000000000000000..2fed1a4d27c34cf5367e6c4a670afc11f65b431f --- /dev/null +++ b/mmpretrain/models/multimodal/otter/otter.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from ..flamingo.flamingo import ExtendModule, Flamingo, PerceiverResampler + + +@MODELS.register_module() +class Otter(Flamingo): + """The Otter model for multiple tasks. + + Args: + vision_encoder (dict): The config of the vision encoder. + lang_encoder (dict): The config of the language encoder. + tokenizer (dict): The tokenizer to encode the text. + task (int): The task to perform prediction. + zeroshot_prompt (str): Prompt used for zero-shot inference. + Defaults to an. + shot_prompt_tmpl (str): Prompt used for few-shot inference. + Defaults to 'User:Please describe the image. + GPT:{caption}<|endofchunk|>'. + final_prompt_tmpl (str): Final part of prompt used for inference. + Defaults to 'User:Please describe the image. GPT:'. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of [~`transformers.GenerationConfig`]. + Defaults to an empty dict. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + + support_tasks = {'caption', 'vqa'} + _no_split_modules = [ + 'TransformerEncoderLayer', 'PerceiverAttention', + 'GatedCrossAttentionBlock', 'FlamingoLayer' + ] + + def __init__( + self, + vision_encoder: dict, + lang_encoder: dict, + tokenizer: dict, + task: str = 'caption', + zeroshot_prompt: str = '', + shot_prompt_tmpl: str = ('User:Please describe the image. ' + 'GPT:{caption}<|endofchunk|>'), + final_prompt_tmpl: str = ('User:Please describe the image. ' + 'GPT:'), + generation_cfg: dict = dict(), + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(Flamingo, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if task not in self.support_tasks: + raise ValueError(f'Unsupported task {task}, please select ' + f'the task from {self.support_tasks}.') + self.task = task + + # init tokenizer + self.tokenizer = TOKENIZER.build(tokenizer) + # add Otter special tokens to the tokenizer + self.tokenizer.add_special_tokens({ + 'additional_special_tokens': + ['<|endofchunk|>', '', ''] + }) + self.tokenizer.bos_token_id = 1 + if self.tokenizer.pad_token is None: + # Issue: GPT models don't have a pad token, which we use to + # modify labels for the loss. + self.tokenizer.add_special_tokens({'pad_token': ''}) + + # Template to format the prompt input + self.zeroshot_prompt = zeroshot_prompt + self.shot_prompt_tmpl = shot_prompt_tmpl + self.final_prompt_tmpl = final_prompt_tmpl + + # init vision encoder related modules + vision_encoder_weight = vision_encoder.pop('pretrained', None) + self.vision_encoder = MODELS.build(vision_encoder) + if vision_encoder_weight is not None: + from mmengine.runner.checkpoint import load_checkpoint + load_checkpoint( + self.vision_encoder, + vision_encoder_weight, + map_location='cpu', + revise_keys=[(r'^backbone\.', '')], + ) + + self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims) + + # init language encoder related modules + self.lang_encoder = ExtendModule(**lang_encoder) + self.lang_encoder.resize_token_embeddings(len(self.tokenizer)) + self.lang_encoder.media_token_id = self.tokenizer.encode('')[-1] + + # other necessary parameters + self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1] + self.generation_cfg = generation_cfg + + if hasattr(self, 'register_load_state_dict_post_hook'): + self.register_load_state_dict_post_hook(self._load_adapter_hook) + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. + """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = output + elif self.task == 'vqa': + data_sample.pred_answer = output + + return data_samples diff --git a/mmpretrain/models/necks/__init__.py b/mmpretrain/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2952a691758843436dd70ad6a11a390216ac724a --- /dev/null +++ b/mmpretrain/models/necks/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .beitv2_neck import BEiTV2Neck +from .cae_neck import CAENeck +from .densecl_neck import DenseCLNeck +from .gap import GlobalAveragePooling +from .gem import GeneralizedMeanPooling +from .hr_fuse import HRFuseScales +from .itpn_neck import iTPNPretrainDecoder +from .linear_neck import LinearNeck +from .mae_neck import ClsBatchNormNeck, MAEPretrainDecoder +from .milan_neck import MILANPretrainDecoder +from .mixmim_neck import MixMIMPretrainDecoder +from .mocov2_neck import MoCoV2Neck +from .nonlinear_neck import NonLinearNeck +from .simmim_neck import SimMIMLinearDecoder +from .spark_neck import SparKLightDecoder +from .swav_neck import SwAVNeck + +__all__ = [ + 'GlobalAveragePooling', + 'GeneralizedMeanPooling', + 'HRFuseScales', + 'LinearNeck', + 'BEiTV2Neck', + 'CAENeck', + 'DenseCLNeck', + 'MAEPretrainDecoder', + 'ClsBatchNormNeck', + 'MILANPretrainDecoder', + 'MixMIMPretrainDecoder', + 'MoCoV2Neck', + 'NonLinearNeck', + 'SimMIMLinearDecoder', + 'SwAVNeck', + 'iTPNPretrainDecoder', + 'SparKLightDecoder', +] diff --git a/mmpretrain/models/necks/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d05728ca735f41e20046042bb8054194b54c647f Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/beitv2_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/beitv2_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a8a9769971feee2de0ac3d4469ba8d209d508f4 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/beitv2_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/cae_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/cae_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b9d105ef108c6bd4593b9fb2a311120cae3ac7 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/cae_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/densecl_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/densecl_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5bb58d625df2f76376ca8fe71d58343eb6af2e6 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/densecl_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/gap.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/gap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4491e739566c6fe10cefbafe81c25aa3c9552126 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/gap.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/gem.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/gem.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e8bb58d58cae300cf9c6ba65c8db71580ee45c0 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/gem.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/hr_fuse.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/hr_fuse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fdfb70cdc60d031eb25acd65d81d71b8d81607d Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/hr_fuse.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/itpn_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/itpn_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e231453d6595976a0981ce7e16416c295e2ad61 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/itpn_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/linear_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/linear_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51f3c71702a44dbcf18b4fe62ddf637f6e8c848a Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/linear_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/mae_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/mae_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..303329865a6603fa7efd7241d05ca26c54aec440 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/mae_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/milan_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/milan_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6649338d59ca1a33507835b234709aa449eb8336 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/milan_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/mixmim_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/mixmim_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d092772087976c8619c7c3729914dbdf43e8476b Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/mixmim_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/mocov2_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/mocov2_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25b52a4bf7a1f4df84c480feb4785d668074e739 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/mocov2_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/nonlinear_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/nonlinear_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..599e99ee4c60c69038434f7b3335c740c173dd8e Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/nonlinear_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/simmim_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/simmim_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ffe36b554c35c20dbfb8652c3946a75954ebedb Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/simmim_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/spark_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/spark_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d16545d68aa77b92f00b1db9108c46f48a57bf6d Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/spark_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/swav_neck.cpython-38.pyc b/mmpretrain/models/necks/__pycache__/swav_neck.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4bbb6a20b900706dbcf59d67f646a8b02e52c73 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/swav_neck.cpython-38.pyc differ diff --git a/mmpretrain/models/necks/beitv2_neck.py b/mmpretrain/models/necks/beitv2_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..745e3879f5e3a4b9269687797728354cb6cf7d4e --- /dev/null +++ b/mmpretrain/models/necks/beitv2_neck.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class BEiTV2Neck(BaseModule): + """Neck for BEiTV2 Pre-training. + + This module construct the decoder for the final prediction. + + Args: + num_layers (int): Number of encoder layers of neck. Defaults to 2. + early_layers (int): The layer index of the early output from the + backbone. Defaults to 9. + backbone_arch (str): Vision Transformer architecture. Defaults to base. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initialization value for the + learnable scaling of attention and FFN. Defaults to 0.1. + use_rel_pos_bias (bool): Whether to use unique relative position bias, + if False, use shared relative position bias defined in backbone. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'depth': 12, + 'num_heads': 12, + 'feedforward_channels': 3072, + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'depth': 24, + 'num_heads': 16, + 'feedforward_channels': 4096, + }), + } + + def __init__( + self, + num_layers: int = 2, + early_layers: int = 9, + backbone_arch: str = 'base', + drop_rate: float = 0., + drop_path_rate: float = 0., + layer_scale_init_value: float = 0.1, + use_rel_pos_bias: bool = False, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + + if isinstance(backbone_arch, str): + backbone_arch = backbone_arch.lower() + assert backbone_arch in set(self.arch_zoo), \ + (f'Arch {backbone_arch} is not in default archs ' + f'{set(self.arch_zoo)}') + self.arch_settings = self.arch_zoo[backbone_arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(backbone_arch, dict) and essential_keys <= set( + backbone_arch + ), f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = backbone_arch + + # stochastic depth decay rule + self.early_layers = early_layers + depth = self.arch_settings['depth'] + dpr = np.linspace(0, drop_path_rate, + max(depth, early_layers + num_layers)) + + self.patch_aggregation = nn.ModuleList() + for i in range(early_layers, early_layers + num_layers): + _layer_cfg = dict( + embed_dims=self.arch_settings['embed_dims'], + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + window_size=None, + use_rel_pos_bias=use_rel_pos_bias) + self.patch_aggregation.append( + BEiTTransformerEncoderLayer(**_layer_cfg)) + + self.rescale_patch_aggregation_init_weight() + + embed_dims = self.arch_settings['embed_dims'] + _, norm = build_norm_layer(norm_cfg, embed_dims) + self.add_module('norm', norm) + + def rescale_patch_aggregation_init_weight(self): + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.patch_aggregation): + rescale(layer.attn.proj.weight.data, + self.early_layers + layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, + self.early_layers + layer_id + 1) + + def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the latent prediction and final prediction. + + Args: + x (Tuple[torch.Tensor]): Features of tokens. + rel_pos_bias (torch.Tensor): Shared relative position bias table. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - ``x``: The final layer features from backbone, which are normed + in ``BEiTV2Neck``. + - ``x_cls_pt``: The early state features from backbone, which are + consist of final layer cls_token and early state patch_tokens + from backbone and sent to PatchAggregation layers in the neck. + """ + + early_states, x = inputs[0], inputs[1] + x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1) + for layer in self.patch_aggregation: + x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias) + + # shared norm + x, x_cls_pt = self.norm(x), self.norm(x_cls_pt) + + # remove cls_token + x = x[:, 1:] + x_cls_pt = x_cls_pt[:, 1:] + return x, x_cls_pt diff --git a/mmpretrain/models/necks/cae_neck.py b/mmpretrain/models/necks/cae_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..81fc30111362ca6f602a0d3f456fbc991926a99f --- /dev/null +++ b/mmpretrain/models/necks/cae_neck.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer +from mmpretrain.registry import MODELS +from ..utils import CrossMultiheadAttention + + +class CAETransformerRegressorLayer(BaseModule): + """Transformer layer for the regressor of CAE. + + This module is different from conventional transformer encoder layer, for + its queries are the masked tokens, but its keys and values are the + concatenation of the masked and unmasked tokens. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): The number of heads in multi-head attention. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + drop_rate (float): The dropout rate. Defaults to 0.0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The init value of gamma. + Defaults to 0.0. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + num_fcs: int = 2, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + layer_scale_init_value: float = 0.0, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN', eps=1e-6) + ) -> None: + super().__init__() + + # NOTE: cross attention + _, self.norm1_q_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm1_k_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm1_v_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm2_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2) + self.cross_attn = CrossMultiheadAttention( + embed_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop_rate, + proj_drop=drop_rate) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=None, + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = DropPath(drop_prob=drop_path_rate) + + if layer_scale_init_value > 0: + self.gamma_1_cross = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + self.gamma_2_cross = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + else: + self.gamma_1_cross = nn.Parameter( + torch.ones((embed_dims)), requires_grad=False) + self.gamma_2_cross = nn.Parameter( + torch.ones((embed_dims)), requires_grad=False) + + def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, + pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = x_q + self.drop_path(self.gamma_1_cross * self.cross_attn( + self.norm1_q_cross(x_q + pos_q), + k=self.norm1_k_cross(x_kv + pos_k), + v=self.norm1_v_cross(x_kv))) + x = self.norm2_cross(x) + x = x + self.drop_path(self.gamma_2_cross * self.ffn(x)) + + return x + + +@MODELS.register_module() +class CAENeck(BaseModule): + """Neck for CAE Pre-training. + + This module construct the latent prediction regressor and the decoder + for the latent prediction and final prediction. + + Args: + num_classes (int): The number of classes for final prediction. Defaults + to 8192. + embed_dims (int): The embed dims of latent feature in regressor and + decoder. Defaults to 768. + regressor_depth (int): The number of regressor blocks. Defaults to 6. + decoder_depth (int): The number of decoder blocks. Defaults to 8. + num_heads (int): The number of head in multi-head attention. Defaults + to 12. + mlp_ratio (int): The expand ratio of latent features in MLP. defaults + to 4. + qkv_bias (bool): Whether or not to use qkv bias. Defaults to True. + qk_scale (float, optional): The scale applied to the results of qk. + Defaults to None. + drop_rate (float): The dropout rate. Defaults to 0. + attn_drop_rate (float): The dropout rate in attention block. Defaults + to 0. + norm_cfg (dict): The config of normalization layer. Defaults to + dict(type='LN', eps=1e-6). + layer_scale_init_value (float, optional): The init value of gamma. + Defaults to None. + mask_tokens_num (int): The number of mask tokens. Defaults to 75. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + num_classes: int = 8192, + embed_dims: int = 768, + regressor_depth: int = 6, + decoder_depth: int = 8, + num_heads: int = 12, + mlp_ratio: int = 4, + qkv_bias: bool = True, + qk_scale: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_cfg: dict = dict(type='LN', eps=1e-6), + layer_scale_init_value: float = None, + mask_tokens_num: int = 75, + init_cfg: dict = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.num_features = self.embed_dim = embed_dims + self.mask_token_num = mask_tokens_num + + # regressor + regressor_drop_path_rates = [ + x.item() + for x in torch.linspace(0, drop_path_rate, regressor_depth) + ] + self.regressors = nn.ModuleList([ + CAETransformerRegressorLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=regressor_drop_path_rates[i], + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value) + for i in range(regressor_depth) + ]) + + # decoder + decoder_drop_path_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, decoder_depth) + ] + self.decoders = nn.ModuleList([ + BEiTTransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + layer_scale_init_value=layer_scale_init_value, + window_size=None, + # setting `use_rel_pos_bias` to False ignores the `window_size` + use_rel_pos_bias=False, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=decoder_drop_path_rates[i], + norm_cfg=norm_cfg) for i in range(decoder_depth) + ]) + + _, self.norm_regressor = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm_decoder = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + + self.head = nn.Linear( + embed_dims, num_classes) if num_classes > 0 else nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + def init_weights(self) -> None: + """Initialization.""" + super().init_weights() + self.apply(self._init_weights) + trunc_normal_(self.mask_token, std=0.02) + trunc_normal_(self.head.weight, std=0.02) + + def _init_weights(self, m: nn.Module) -> None: + """Initialization.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, x_unmasked: torch.Tensor, pos_embed_masked: torch.Tensor, + pos_embed_unmasked: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the latent prediction and final prediction. + + Args: + x_unmasked (torch.Tensor): Features of unmasked tokens. + pos_embed_masked (torch.Tensor): Position embedding of masked + tokens. + pos_embed_unmasked (torch.Tensor): Position embedding of unmasked + tokens. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - ``logits``: Final prediction. + - ``latent_pred``: Latent prediction. + """ + x_masked = self.mask_token.expand(x_unmasked.shape[0], + self.mask_token_num, -1) + # regressor + for regressor in self.regressors: + x_masked = regressor( + x_masked, torch.cat([x_unmasked, x_masked], dim=1), + pos_embed_masked, + torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1)) + x_masked = self.norm_regressor(x_masked) + latent_pred = x_masked + + # decoder + x_masked = x_masked + pos_embed_masked + for decoder in self.decoders: + x_masked = decoder(x_masked, rel_pos_bias=None) + x_masked = self.norm_decoder(x_masked) + + logits = self.head(x_masked) + + return logits, latent_pred diff --git a/mmpretrain/models/necks/densecl_neck.py b/mmpretrain/models/necks/densecl_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..bee9a2368d8917ece7b4b8ab8d1398ce951ede24 --- /dev/null +++ b/mmpretrain/models/necks/densecl_neck.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DenseCLNeck(BaseModule): + """The non-linear neck of DenseCL. + + Single and dense neck in parallel: fc-relu-fc, conv-relu-conv. + Borrowed from the authors' `code `_. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + num_grid (int): The grid size of dense features. Defaults to None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + hid_channels: int, + out_channels: int, + num_grid: Optional[int] = None, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels)) + + self.with_pool = True if num_grid is not None else False + if self.with_pool: + self.pool = nn.AdaptiveAvgPool2d((num_grid, num_grid)) + self.mlp2 = nn.Sequential( + nn.Conv2d(in_channels, hid_channels, 1), nn.ReLU(inplace=True), + nn.Conv2d(hid_channels, out_channels, 1)) + self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function of neck. + + Args: + x (Tuple[torch.Tensor]): feature map of backbone. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - ``avgpooled_x``: Global feature vectors. + - ``x``: Dense feature vectors. + - ``avgpooled_x2``: Dense feature vectors for queue. + """ + assert len(x) == 1 + x = x[0] + + avgpooled_x = self.avgpool(x) + avgpooled_x = self.mlp(avgpooled_x.view(avgpooled_x.size(0), -1)) + + if self.with_pool: + x = self.pool(x) # sxs + x = self.mlp2(x) # sxs: bxdxsxs + avgpooled_x2 = self.avgpool2(x) # 1x1: bxdx1x1 + x = x.view(x.size(0), x.size(1), -1) # bxdxs^2 + avgpooled_x2 = avgpooled_x2.view(avgpooled_x2.size(0), -1) # bxd + return avgpooled_x, x, avgpooled_x2 diff --git a/mmpretrain/models/necks/gap.py b/mmpretrain/models/necks/gap.py new file mode 100644 index 0000000000000000000000000000000000000000..0877743ad1e5a75976feb14f5d34942c0b7b8ee4 --- /dev/null +++ b/mmpretrain/models/necks/gap.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class GlobalAveragePooling(nn.Module): + """Global Average Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + + Args: + dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}. + Default: 2 + """ + + def __init__(self, dim=2): + super(GlobalAveragePooling, self).__init__() + assert dim in [1, 2, 3], 'GlobalAveragePooling dim only support ' \ + f'{1, 2, 3}, get {dim} instead.' + if dim == 1: + self.gap = nn.AdaptiveAvgPool1d(1) + elif dim == 2: + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + else: + self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) + + def init_weights(self): + pass + + def forward(self, inputs): + if isinstance(inputs, tuple): + outs = tuple([self.gap(x) for x in inputs]) + outs = tuple( + [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) + elif isinstance(inputs, torch.Tensor): + outs = self.gap(inputs) + outs = outs.view(inputs.size(0), -1) + else: + raise TypeError('neck inputs should be tuple or torch.tensor') + return outs diff --git a/mmpretrain/models/necks/gem.py b/mmpretrain/models/necks/gem.py new file mode 100644 index 0000000000000000000000000000000000000000..f5648be86303caa6f2c25786fe8c3058c2f98d7e --- /dev/null +++ b/mmpretrain/models/necks/gem.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from torch.nn.parameter import Parameter + +from mmpretrain.registry import MODELS + + +def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor: + if clamp: + x = x.clamp(min=eps) + return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p) + + +@MODELS.register_module() +class GeneralizedMeanPooling(nn.Module): + """Generalized Mean Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + + Args: + p (float): Parameter value. Defaults to 3. + eps (float): epsilon. Defaults to 1e-6. + clamp (bool): Use clamp before pooling. Defaults to True + p_trainable (bool): Toggle whether Parameter p is trainable or not. + Defaults to True. + """ + + def __init__(self, p=3., eps=1e-6, clamp=True, p_trainable=True): + assert p >= 1, "'p' must be a value greater than 1" + super(GeneralizedMeanPooling, self).__init__() + self.p = Parameter(torch.ones(1) * p, requires_grad=p_trainable) + self.eps = eps + self.clamp = clamp + self.p_trainable = p_trainable + + def forward(self, inputs): + if isinstance(inputs, tuple): + outs = tuple([ + gem(x, p=self.p, eps=self.eps, clamp=self.clamp) + for x in inputs + ]) + outs = tuple( + [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) + elif isinstance(inputs, torch.Tensor): + outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp) + outs = outs.view(inputs.size(0), -1) + else: + raise TypeError('neck inputs should be tuple or torch.tensor') + return outs diff --git a/mmpretrain/models/necks/hr_fuse.py b/mmpretrain/models/necks/hr_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..4a97f86f9fb9e4cce89e950e54674d5ec3d9b1f7 --- /dev/null +++ b/mmpretrain/models/necks/hr_fuse.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn.bricks import ConvModule +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..backbones.resnet import Bottleneck, ResLayer + + +@MODELS.register_module() +class HRFuseScales(BaseModule): + """Fuse feature map of multiple scales in HRNet. + + Args: + in_channels (list[int]): The input channels of all scales. + out_channels (int): The channels of fused feature map. + Defaults to 2048. + norm_cfg (dict): dictionary to construct norm layers. + Defaults to ``dict(type='BN', momentum=0.1)``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to ``dict(type='Normal', layer='Linear', std=0.01))``. + """ + + def __init__(self, + in_channels, + out_channels=2048, + norm_cfg=dict(type='BN', momentum=0.1), + init_cfg=dict(type='Normal', layer='Linear', std=0.01)): + super(HRFuseScales, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.norm_cfg = norm_cfg + + block_type = Bottleneck + out_channels = [128, 256, 512, 1024] + + # Increase the channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + increase_layers = [] + for i in range(len(in_channels)): + increase_layers.append( + ResLayer( + block_type, + in_channels=in_channels[i], + out_channels=out_channels[i], + num_blocks=1, + stride=1, + )) + self.increase_layers = nn.ModuleList(increase_layers) + + # Downsample feature maps in each scale. + downsample_layers = [] + for i in range(len(in_channels) - 1): + downsample_layers.append( + ConvModule( + in_channels=out_channels[i], + out_channels=out_channels[i + 1], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + bias=False, + )) + self.downsample_layers = nn.ModuleList(downsample_layers) + + # The final conv block before final classifier linear layer. + self.final_layer = ConvModule( + in_channels=out_channels[3], + out_channels=self.out_channels, + kernel_size=1, + norm_cfg=self.norm_cfg, + bias=False, + ) + + def forward(self, x): + assert isinstance(x, tuple) and len(x) == len(self.in_channels) + + feat = self.increase_layers[0](x[0]) + for i in range(len(self.downsample_layers)): + feat = self.downsample_layers[i](feat) + \ + self.increase_layers[i + 1](x[i + 1]) + + return (self.final_layer(feat), ) diff --git a/mmpretrain/models/necks/itpn_neck.py b/mmpretrain/models/necks/itpn_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3626af634b185fef9b0b2fb47c1fdc15e1139b --- /dev/null +++ b/mmpretrain/models/necks/itpn_neck.py @@ -0,0 +1,388 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.models.backbones.hivit import BlockWithRPE +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import build_2d_sincos_position_embedding + + +class PatchSplit(nn.Module): + """The up-sample module used in neck (transformer pyramid network) + + Args: + dim (int): the input dimension (channel number). + fpn_dim (int): the fpn dimension (channel number). + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + """ + + def __init__(self, dim, fpn_dim, norm_cfg): + super().__init__() + _, self.norm = build_norm_layer(norm_cfg, dim) + self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False) + self.fpn_dim = fpn_dim + + def forward(self, x): + B, N, H, W, C = x.shape + x = self.norm(x) + x = self.reduction(x) + x = x.reshape(B, N, H, W, 2, 2, + self.fpn_dim).permute(0, 1, 2, 4, 3, 5, + 6).reshape(B, N, 2 * H, 2 * W, + self.fpn_dim) + return x + + +@MODELS.register_module() +class iTPNPretrainDecoder(BaseModule): + """The neck module of iTPN (transformer pyramid network). + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 512. + fpn_dim (int): The fpn dimension (channel number). + fpn_depth (int): The layer number of feature pyramid. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + reconstruction_type (str): The itpn supports 2 kinds of supervisions. + Defaults to 'pixel'. + num_outs (int): The output number of neck (transformer pyramid + network). Defaults to 3. + predict_feature_dim (int): The output dimension to supervision. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 512, + fpn_dim: int = 256, + fpn_depth: int = 2, + decoder_embed_dim: int = 512, + decoder_depth: int = 6, + decoder_num_heads: int = 16, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + reconstruction_type: str = 'pixel', + num_outs: int = 3, + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + predict_feature_dim: Optional[float] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.num_patches = num_patches + assert reconstruction_type in ['pixel', 'clip'], \ + 'iTPN method only support `pixel` and `clip`, ' \ + f'but got `{reconstruction_type}`.' + self.reconstruction_type = reconstruction_type + self.num_outs = num_outs + + self.build_transformer_pyramid( + num_outs=num_outs, + embed_dim=embed_dim, + fpn_dim=fpn_dim, + fpn_depth=fpn_depth, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + rpe=False, + norm_cfg=norm_cfg, + ) + + # merge the output + self.decoder_embed = nn.ModuleList() + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim, bias=True), + )) + + if self.num_outs >= 2: + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True), + )) + if self.num_outs >= 3: + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True), + )) + + if reconstruction_type == 'pixel': + self.mask_token = nn.Parameter( + torch.zeros(1, 1, decoder_embed_dim)) + + # create new position embedding, different from that in encoder + # and is not learnable + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, decoder_embed_dim), + requires_grad=False) + + self.decoder_blocks = nn.ModuleList([ + TransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + self.decoder_norm_name, decoder_norm = build_norm_layer( + norm_cfg, decoder_embed_dim, postfix=1) + self.add_module(self.decoder_norm_name, decoder_norm) + + # Used to map features to pixels + if predict_feature_dim is None: + predict_feature_dim = patch_size**2 * in_chans + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + else: + _, norm = build_norm_layer(norm_cfg, embed_dim) + self.add_module('norm', norm) + + def build_transformer_pyramid(self, + num_outs=3, + embed_dim=512, + fpn_dim=256, + fpn_depth=2, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + rpe=False, + norm_cfg=None): + Hp = None + mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim} + if num_outs > 1: + if embed_dim != fpn_dim: + self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim) + else: + self.align_dim_16tofpn = None + self.fpn_modules = nn.ModuleList() + self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg)) + self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=False, + norm_cfg=norm_cfg, + )) + + self.align_dim_16to8 = nn.Linear( + mlvl_dims['8'], fpn_dim, bias=False) + self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg) + self.block_16to8 = nn.Sequential(*[ + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg, + ) for _ in range(fpn_depth) + ]) + + if num_outs > 2: + self.align_dim_8to4 = nn.Linear( + mlvl_dims['4'], fpn_dim, bias=False) + self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg) + self.block_8to4 = nn.Sequential(*[ + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg, + ) for _ in range(fpn_depth) + ]) + self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg)) + + def init_weights(self) -> None: + """Initialize position embedding and mask token of MAE decoder.""" + super().init_weights() + + if self.reconstruction_type == 'pixel': + decoder_pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.decoder_pos_embed.shape[-1], + cls_token=False) + self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) + + torch.nn.init.normal_(self.mask_token, std=.02) + else: + self.rescale_init_weight() + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.fpn_modules): + if isinstance(layer, BlockWithRPE): + if layer.attn is not None: + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + @property + def decoder_norm(self): + """The normalization layer of decoder.""" + return getattr(self, self.decoder_norm_name) + + def forward(self, + x: torch.Tensor, + ids_restore: torch.Tensor = None) -> torch.Tensor: + """The forward function. + + The process computes the visible patches' features vectors and the mask + tokens to output feature vectors, which will be used for + reconstruction. + + Args: + x (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + ids_restore (torch.Tensor): ids to restore original image. + + Returns: + torch.Tensor: The reconstructed feature vectors, which is of + shape B x (num_patches) x C. + """ + + features = x[:2] + x = x[-1] + B, L, _ = x.shape + x = x[..., None, None, :] + Hp = Wp = math.sqrt(L) + + outs = [x] if self.align_dim_16tofpn is None else [ + self.align_dim_16tofpn(x) + ] + if self.num_outs >= 2: + x = self.block_16to8( + self.split_16to8(x) + self.align_dim_16to8(features[1])) + outs.append(x) + if self.num_outs >= 3: + x = self.block_8to4( + self.split_8to4(x) + self.align_dim_8to4(features[0])) + outs.append(x) + if self.num_outs > 3: + outs = [ + out.reshape(B, Hp, Wp, *out.shape[-3:]).permute( + 0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3], + Wp * out.shape[-2]).contiguous() + for out in outs + ] + if self.num_outs >= 4: + outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2)) + if self.num_outs >= 5: + outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2)) + + for i, out in enumerate(outs): + out = self.fpn_modules[i](out) + outs[i] = out + + if self.reconstruction_type == 'pixel': + feats = [] + for feat, layer in zip(outs, self.decoder_embed): + x = layer(feat).reshape(B, L, -1) + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x = torch.cat([x, mask_tokens], dim=1) + x = torch.gather( + x, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + feats.append(x) + x = feats.pop(0) + # add pos embed + x = x + self.decoder_pos_embed + + for i, feat in enumerate(feats): + x = x + feats[i] + # apply Transformer blocks + for i, blk in enumerate(self.decoder_blocks): + x = blk(x) + x = self.decoder_norm(x) + x = self.decoder_pred(x) + return x + else: + feats = [] + for feat, layer in zip(outs, self.decoder_embed): + x = layer(feat).reshape(B, L, -1) + feats.append(x) + x = feats.pop(0) + for i, feat in enumerate(feats): + x = x + feats[i] + + x = self.norm(x) + + return x diff --git a/mmpretrain/models/necks/linear_neck.py b/mmpretrain/models/necks/linear_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..bcdbee264325c8db0a054f765651a5dbadc968db --- /dev/null +++ b/mmpretrain/models/necks/linear_neck.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class LinearNeck(BaseModule): + """Linear neck with Dimension projection. + + Args: + in_channels (int): Number of channels in the input. + out_channels (int): Number of channels in the output. + gap_dim (int): Dimensions of each sample channel, can be one of + {0, 1, 2, 3}. Defaults to 0. + norm_cfg (dict, optional): dictionary to construct and + config norm layer. Defaults to dict(type='BN1d'). + act_cfg (dict, optional): dictionary to construct and + config activate layer. Defaults to None. + init_cfg (dict, optional): dictionary to initialize weights. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + gap_dim: int = 0, + norm_cfg: Optional[dict] = dict(type='BN1d'), + act_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + self.norm_cfg = copy.deepcopy(norm_cfg) + self.act_cfg = copy.deepcopy(act_cfg) + + assert gap_dim in [0, 1, 2, 3], 'GlobalAveragePooling dim only ' \ + f'support {0, 1, 2, 3}, get {gap_dim} instead.' + if gap_dim == 0: + self.gap = nn.Identity() + elif gap_dim == 1: + self.gap = nn.AdaptiveAvgPool1d(1) + elif gap_dim == 2: + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + elif gap_dim == 3: + self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) + + self.fc = nn.Linear(in_features=in_channels, out_features=out_channels) + + if norm_cfg: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + else: + self.norm = nn.Identity() + + if act_cfg: + self.act = build_activation_layer(act_cfg) + else: + self.act = nn.Identity() + + def forward(self, inputs: Union[Tuple, + torch.Tensor]) -> Tuple[torch.Tensor]: + """forward function. + + Args: + inputs (Union[Tuple, torch.Tensor]): The features extracted from + the backbone. Multiple stage inputs are acceptable but only + the last stage will be used. + + Returns: + Tuple[torch.Tensor]: A tuple of output features. + """ + assert isinstance(inputs, (tuple, torch.Tensor)), ( + 'The inputs of `LinearNeck` must be tuple or `torch.Tensor`, ' + f'but get {type(inputs)}.') + if isinstance(inputs, tuple): + inputs = inputs[-1] + + x = self.gap(inputs) + x = x.view(x.size(0), -1) + out = self.act(self.norm(self.fc(x))) + return (out, ) diff --git a/mmpretrain/models/necks/mae_neck.py b/mmpretrain/models/necks/mae_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..773692dcb3a94d85d2d2085360fd339493a24db3 --- /dev/null +++ b/mmpretrain/models/necks/mae_neck.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import build_2d_sincos_position_embedding + + +@MODELS.register_module() +class MAEPretrainDecoder(BaseModule): + """Decoder for MAE Pre-training. + + Some of the code is borrowed from `https://github.com/facebookresearch/mae`. # noqa + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 1024. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + + Example: + >>> from mmpretrain.models import MAEPretrainDecoder + >>> import torch + >>> self = MAEPretrainDecoder() + >>> self.eval() + >>> inputs = torch.rand(1, 50, 1024) + >>> ids_restore = torch.arange(0, 196).unsqueeze(0) + >>> level_outputs = self.forward(inputs, ids_restore) + >>> print(tuple(level_outputs.shape)) + (1, 196, 768) + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + predict_feature_dim: Optional[float] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.num_patches = num_patches + + # used to convert the dim of features from encoder to the dim + # compatible with that of decoder + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + # create new position embedding, different from that in encoder + # and is not learnable + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, decoder_embed_dim), + requires_grad=False) + + self.decoder_blocks = nn.ModuleList([ + TransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + self.decoder_norm_name, decoder_norm = build_norm_layer( + norm_cfg, decoder_embed_dim, postfix=1) + self.add_module(self.decoder_norm_name, decoder_norm) + + # Used to map features to pixels + if predict_feature_dim is None: + predict_feature_dim = patch_size**2 * in_chans + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + + def init_weights(self) -> None: + """Initialize position embedding and mask token of MAE decoder.""" + super().init_weights() + + decoder_pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.decoder_pos_embed.shape[-1], + cls_token=True) + self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) + + torch.nn.init.normal_(self.mask_token, std=.02) + + @property + def decoder_norm(self): + """The normalization layer of decoder.""" + return getattr(self, self.decoder_norm_name) + + def forward(self, x: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: + """The forward function. + + The process computes the visible patches' features vectors and the mask + tokens to output feature vectors, which will be used for + reconstruction. + + Args: + x (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + ids_restore (torch.Tensor): ids to restore original image. + + Returns: + torch.Tensor: The reconstructed feature vectors, which is of + shape B x (num_patches) x C. + """ + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + x = torch.cat([x[:, :1, :], x_], dim=1) + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + # remove cls token + x = x[:, 1:, :] + + return x + + +@MODELS.register_module() +class ClsBatchNormNeck(BaseModule): + """Normalize cls token across batch before head. + + This module is proposed by MAE, when running linear probing. + + Args: + input_features (int): The dimension of features. + affine (bool): a boolean value that when set to ``True``, this module + has learnable affine parameters. Defaults to False. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-6. + init_cfg (Dict or List[Dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + input_features: int, + affine: bool = False, + eps: float = 1e-6, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.bn = nn.BatchNorm1d(input_features, affine=affine, eps=eps) + + def forward( + self, + inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]: + """The forward function.""" + # Only apply batch norm to cls_token + inputs = [self.bn(input_) for input_ in inputs] + return tuple(inputs) diff --git a/mmpretrain/models/necks/milan_neck.py b/mmpretrain/models/necks/milan_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..b48b76787231cfe9671e9f12900b6db1987a7e2a --- /dev/null +++ b/mmpretrain/models/necks/milan_neck.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from torch import nn + +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import PromptMultiheadAttention +from .mae_neck import MAEPretrainDecoder + + +class PromptTransformerEncoderLayer(TransformerEncoderLayer): + """Prompt Transformer Encoder Layer for MILAN. + + This module is specific for the prompt encoder in MILAN. It will not update + the visible tokens from the encoder. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): Enable bias for qkv if True. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels=int, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + self.attn = PromptMultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias) + + def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: + """Forward function for `PromptMultiheadAttention`. + + Args: + x (torch.Tensor): Mask token features with shape N x L_m x C. + visible_tokens (torch.Tensor): The visible tokens features from + encoder with shape N x L_v x C. + ids_restore (torch.Tensor): The ids of all tokens in the original + image with shape N x L. + + Returns: + torch Tensor: Output features with shape N x L x C. + """ + x = x + self.attn(self.norm1(x), visible_tokens, ids_restore) + x = self.ffn(self.norm2(x), identity=x) + return x + + +@MODELS.register_module() +class MILANPretrainDecoder(MAEPretrainDecoder): + """Prompt decoder for MILAN. + + This decoder is used in MILAN pretraining, which will not update these + visible tokens from the encoder. + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 1024. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + predict_feature_dim (int): The dimension of the feature to be + predicted. Defaults to 512. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + predict_feature_dim: int = 512, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + num_patches=num_patches, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + decoder_embed_dim=decoder_embed_dim, + decoder_depth=decoder_depth, + decoder_num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + # map the dim of features from decoder to the dim compatible with + # that of CLIP + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + + # use prompt transformer encoder layer, instead of the conventional + # transformer encoder layer + self.decoder_blocks = nn.ModuleList([ + PromptTransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + def forward(self, x: torch.Tensor, ids_restore: torch.Tensor, + ids_keep: torch.Tensor, + ids_dump: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): The input features, which is of shape (N, L, C). + ids_restore (torch.Tensor): The indices to restore these tokens + to the original image. + ids_keep (torch.Tensor): The indices of tokens to be kept. + ids_dump (torch.Tensor): The indices of tokens to be masked. + + Returns: + torch.Tensor: The reconstructed features, which is of shape + (N, L, C). + """ + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + x = torch.cat([x[:, :1, :], x_], dim=1) + + # add pos embed + x = x + self.decoder_pos_embed + + # split mask tokens and visible tokens + visible_tokens = torch.cat([ + x[:, :1, :], + torch.gather( + x[:, 1:, :], + dim=1, + index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + ], + dim=1) + x = torch.gather( + x[:, 1:, :], + dim=1, + index=ids_dump.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + + for blk in self.decoder_blocks: + x = blk(x, visible_tokens, ids_restore) + + # full sequence recovery + x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, + x.shape[-1])) # unshuffle + x = torch.cat([visible_tokens[:, :1, :], x_], dim=1) + + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + return x diff --git a/mmpretrain/models/necks/mixmim_neck.py b/mmpretrain/models/necks/mixmim_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..8d67ee2bd6b48136f2ae6b298e11bd7758fa414b --- /dev/null +++ b/mmpretrain/models/necks/mixmim_neck.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from ..utils import build_2d_sincos_position_embedding +from .mae_neck import MAEPretrainDecoder + + +@MODELS.register_module() +class MixMIMPretrainDecoder(MAEPretrainDecoder): + """Decoder for MixMIM Pretraining. + + Some of the code is borrowed from `https://github.com/Sense-X/MixMIM`. # noqa + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 1024. + encoder_stride (int): The output stride of MixMIM backbone. Defaults + to 32. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + encoder_stride: int = 32, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + + super().__init__( + num_patches=num_patches, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + decoder_embed_dim=decoder_embed_dim, + decoder_depth=decoder_depth, + decoder_num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, decoder_embed_dim), + requires_grad=False) + self.decoder_pred = nn.Linear(decoder_embed_dim, encoder_stride**2 * 3) + + def init_weights(self) -> None: + """Initialize position embedding and mask token of MixMIM decoder.""" + super(MAEPretrainDecoder, self).init_weights() + + decoder_pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.decoder_pos_embed.shape[-1], + cls_token=False) + self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) + + torch.nn.init.normal_(self.mask_token, std=.02) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): The input features, which is of shape (N, L, C). + mask (torch.Tensor): The tensor to indicate which tokens a + re masked. + + Returns: + torch.Tensor: The reconstructed features, which is of shape + (N, L, C). + """ + + x = self.decoder_embed(x) + B, L, C = x.shape + + mask_tokens = self.mask_token.expand(B, L, -1) + x1 = x * (1 - mask) + mask_tokens * mask + x2 = x * mask + mask_tokens * (1 - mask) + x = torch.cat([x1, x2], dim=0) + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for idx, blk in enumerate(self.decoder_blocks): + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + return x diff --git a/mmpretrain/models/necks/mocov2_neck.py b/mmpretrain/models/necks/mocov2_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad9107812eb9aaaaff8cbc1a7d5c3d39e92dfa1 --- /dev/null +++ b/mmpretrain/models/necks/mocov2_neck.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MoCoV2Neck(BaseModule): + """The non-linear neck of MoCo v2: fc-relu-fc. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + with_avg_pool (bool): Whether to apply the global + average pooling after backbone. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + hid_channels: int, + out_channels: int, + with_avg_pool: bool = True, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.with_avg_pool = with_avg_pool + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels)) + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function. + + Args: + x (Tuple[torch.Tensor]): The feature map of backbone. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + assert len(x) == 1 + x = x[0] + if self.with_avg_pool: + x = self.avgpool(x) + return (self.mlp(x.view(x.size(0), -1)), ) diff --git a/mmpretrain/models/necks/nonlinear_neck.py b/mmpretrain/models/necks/nonlinear_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..ef684d39d1f7f5dc7361ccbf631d3ce712d65ac5 --- /dev/null +++ b/mmpretrain/models/necks/nonlinear_neck.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class NonLinearNeck(BaseModule): + """The non-linear neck. + + Structure: fc-bn-[relu-fc-bn] where the substructure in [] can be repeated. + For the default setting, the repeated time is 1. + The neck can be used in many algorithms, e.g., SimCLR, BYOL, SimSiam. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + num_layers (int): Number of fc layers. Defaults to 2. + with_bias (bool): Whether to use bias in fc layers (except for the + last). Defaults to False. + with_last_bn (bool): Whether to add the last BN layer. + Defaults to True. + with_last_bn_affine (bool): Whether to have learnable affine parameters + in the last BN layer (set False for SimSiam). Defaults to True. + with_last_bias (bool): Whether to use bias in the last fc layer. + Defaults to False. + with_avg_pool (bool): Whether to apply the global average pooling + after backbone. Defaults to True. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to dict(type='SyncBN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__( + self, + in_channels: int, + hid_channels: int, + out_channels: int, + num_layers: int = 2, + with_bias: bool = False, + with_last_bn: bool = True, + with_last_bn_affine: bool = True, + with_last_bias: bool = False, + with_avg_pool: bool = True, + norm_cfg: dict = dict(type='SyncBN'), + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + ) -> None: + super(NonLinearNeck, self).__init__(init_cfg) + self.with_avg_pool = with_avg_pool + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.relu = nn.ReLU(inplace=True) + self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias) + self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1] + + self.fc_names = [] + self.bn_names = [] + for i in range(1, num_layers): + this_channels = out_channels if i == num_layers - 1 \ + else hid_channels + if i != num_layers - 1: + self.add_module( + f'fc{i}', + nn.Linear(hid_channels, this_channels, bias=with_bias)) + self.add_module(f'bn{i}', + build_norm_layer(norm_cfg, this_channels)[1]) + self.bn_names.append(f'bn{i}') + else: + self.add_module( + f'fc{i}', + nn.Linear( + hid_channels, this_channels, bias=with_last_bias)) + if with_last_bn: + self.add_module( + f'bn{i}', + build_norm_layer( + dict(**norm_cfg, affine=with_last_bn_affine), + this_channels)[1]) + self.bn_names.append(f'bn{i}') + else: + self.bn_names.append(None) + self.fc_names.append(f'fc{i}') + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function. + + Args: + x (Tuple[torch.Tensor]): The feature map of backbone. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + assert len(x) == 1 + x = x[0] + if self.with_avg_pool: + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc0(x) + x = self.bn0(x) + for fc_name, bn_name in zip(self.fc_names, self.bn_names): + fc = getattr(self, fc_name) + x = self.relu(x) + x = fc(x) + if bn_name is not None: + bn = getattr(self, bn_name) + x = bn(x) + return (x, ) diff --git a/mmpretrain/models/necks/simmim_neck.py b/mmpretrain/models/necks/simmim_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1e29bcf195ecb800a22a2c43917e62718b5ffe --- /dev/null +++ b/mmpretrain/models/necks/simmim_neck.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SimMIMLinearDecoder(BaseModule): + """Linear Decoder For SimMIM pretraining. + + This neck reconstructs the original image from the shrunk feature map. + + Args: + in_channels (int): Channel dimension of the feature map. + encoder_stride (int): The total stride of the encoder. + """ + + def __init__(self, in_channels: int, encoder_stride: int) -> None: + super().__init__() + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=encoder_stride**2 * 3, + kernel_size=1), + nn.PixelShuffle(encoder_stride), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = self.decoder(x) + return x diff --git a/mmpretrain/models/necks/spark_neck.py b/mmpretrain/models/necks/spark_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..ac129da389711f900e4444fae38fdbc7ae91b9e5 --- /dev/null +++ b/mmpretrain/models/necks/spark_neck.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +def is_pow2n(x): + return x > 0 and (x & (x - 1) == 0) + + +class ConvBlock2x(BaseModule): + """The definition of convolution block.""" + + def __init__(self, + in_channels: int, + out_channels: int, + mid_channels: int, + norm_cfg: dict, + act_cfg: dict, + last_act: bool, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=False) + self.norm1 = build_norm_layer(norm_cfg, mid_channels) + self.activate1 = MODELS.build(act_cfg) + + self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False) + self.norm2 = build_norm_layer(norm_cfg, out_channels) + self.activate2 = MODELS.build(act_cfg) if last_act else nn.Identity() + + def forward(self, x: torch.Tensor): + out = self.conv1(x) + out = self.norm1(out) + out = self.activate1(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.activate2(out) + return out + + +class DecoderConvModule(BaseModule): + """The convolution module of decoder with upsampling.""" + + def __init__(self, + in_channels: int, + out_channels: int, + mid_channels: int, + kernel_size: int = 4, + scale_factor: int = 2, + num_conv_blocks: int = 1, + norm_cfg: dict = dict(type='SyncBN'), + act_cfg: dict = dict(type='ReLU6'), + last_act: bool = True, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + + assert (kernel_size - scale_factor >= 0) and\ + (kernel_size - scale_factor) % 2 == 0,\ + f'kernel_size should be greater than or equal to scale_factor '\ + f'and (kernel_size - scale_factor) should be even numbers, '\ + f'while the kernel size is {kernel_size} and scale_factor is '\ + f'{scale_factor}.' + + padding = (kernel_size - scale_factor) // 2 + self.upsample = nn.ConvTranspose2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=scale_factor, + padding=padding, + bias=True) + + conv_blocks_list = [ + ConvBlock2x( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + norm_cfg=norm_cfg, + last_act=last_act, + act_cfg=act_cfg) for _ in range(num_conv_blocks) + ] + self.conv_blocks = nn.Sequential(*conv_blocks_list) + + def forward(self, x): + x = self.upsample(x) + return self.conv_blocks(x) + + +@MODELS.register_module() +class SparKLightDecoder(BaseModule): + """The decoder for SparK, which upsamples the feature maps. + + Args: + feature_dim (int): The dimension of feature map. + upsample_ratio (int): The ratio of upsample, equal to downsample_raito + of the algorithm. + mid_channels (int): The middle channel of `DecoderConvModule`. Defaults + to 0. + kernel_size (int): The kernel size of `ConvTranspose2d` in + `DecoderConvModule`. Defaults to 4. + scale_factor (int): The scale_factor of `ConvTranspose2d` in + `DecoderConvModule`. Defaults to 2. + num_conv_blocks (int): The number of convolution blocks in + `DecoderConvModule`. Defaults to 1. + norm_cfg (dict): Normalization config. Defaults to dict(type='SyncBN'). + act_cfg (dict): Activation config. Defaults to dict(type='ReLU6'). + last_act (bool): Whether apply the last activation in + `DecoderConvModule`. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__( + self, + feature_dim: int, + upsample_ratio: int, + mid_channels: int = 0, + kernel_size: int = 4, + scale_factor: int = 2, + num_conv_blocks: int = 1, + norm_cfg: dict = dict(type='SyncBN'), + act_cfg: dict = dict(type='ReLU6'), + last_act: bool = False, + init_cfg: Optional[dict] = [ + dict(type='Kaiming', layer=['Conv2d', 'ConvTranspose2d']), + dict(type='TruncNormal', std=0.02, layer=['Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'LayerNorm', 'SyncBatchNorm']) + ], + ): + super().__init__(init_cfg=init_cfg) + self.feature_dim = feature_dim + + assert is_pow2n(upsample_ratio) + n = round(math.log2(upsample_ratio)) + channels = [feature_dim // 2**i for i in range(n + 1)] + + self.decoder = nn.ModuleList([ + DecoderConvModule( + in_channels=c_in, + out_channels=c_out, + mid_channels=c_in if mid_channels == 0 else mid_channels, + kernel_size=kernel_size, + scale_factor=scale_factor, + num_conv_blocks=num_conv_blocks, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + last_act=last_act) + for (c_in, c_out) in zip(channels[:-1], channels[1:]) + ]) + self.proj = nn.Conv2d( + channels[-1], 3, kernel_size=1, stride=1, bias=True) + + def forward(self, to_dec): + x = 0 + for i, d in enumerate(self.decoder): + if i < len(to_dec) and to_dec[i] is not None: + x = x + to_dec[i] + x = self.decoder[i](x) + return self.proj(x) diff --git a/mmpretrain/models/necks/swav_neck.py b/mmpretrain/models/necks/swav_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..807ae8b9b3155e9dd14ef95fe5fca526919ee11d --- /dev/null +++ b/mmpretrain/models/necks/swav_neck.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SwAVNeck(BaseModule): + """The non-linear neck of SwAV: fc-bn-relu-fc-normalization. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + with_avg_pool (bool): Whether to apply the global average pooling after + backbone. Defaults to True. + with_l2norm (bool): whether to normalize the output after projection. + Defaults to True. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to dict(type='SyncBN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__( + self, + in_channels: int, + hid_channels: int, + out_channels: int, + with_avg_pool: bool = True, + with_l2norm: bool = True, + norm_cfg: dict = dict(type='SyncBN'), + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + ) -> None: + super().__init__(init_cfg) + self.with_avg_pool = with_avg_pool + self.with_l2norm = with_l2norm + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + if out_channels == 0: + self.projection_neck = nn.Identity() + elif hid_channels == 0: + self.projection_neck = nn.Linear(in_channels, out_channels) + else: + self.norm = build_norm_layer(norm_cfg, hid_channels)[1] + self.projection_neck = nn.Sequential( + nn.Linear(in_channels, hid_channels), + self.norm, + nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels), + ) + + def forward_projection(self, x: torch.Tensor) -> torch.Tensor: + """Compute projection. + + Args: + x (torch.Tensor): The feature vectors after pooling. + + Returns: + torch.Tensor: The output features with projection or L2-norm. + """ + x = self.projection_neck(x) + if self.with_l2norm: + x = nn.functional.normalize(x, dim=1, p=2) + return x + + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + + Args: + x (List[torch.Tensor]): list of feature maps, len(x) according to + len(num_crops). + + Returns: + torch.Tensor: The projection vectors. + """ + avg_out = [] + for _x in x: + _x = _x[0] + if self.with_avg_pool: + _out = self.avgpool(_x) + avg_out.append(_out) + feat_vec = torch.cat(avg_out) # [sum(num_crops) * N, C] + feat_vec = feat_vec.view(feat_vec.size(0), -1) + output = self.forward_projection(feat_vec) + return output diff --git a/mmpretrain/models/retrievers/__init__.py b/mmpretrain/models/retrievers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..593b637d6eb7e44184fdf6ceb70470253639b013 --- /dev/null +++ b/mmpretrain/models/retrievers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseRetriever +from .image2image import ImageToImageRetriever + +__all__ = ['BaseRetriever', 'ImageToImageRetriever'] diff --git a/mmpretrain/models/retrievers/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/retrievers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63fddefd3945e9bbc3b9bbd31e5dbfa36d38694d Binary files /dev/null and b/mmpretrain/models/retrievers/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/retrievers/__pycache__/base.cpython-38.pyc b/mmpretrain/models/retrievers/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61e9f06c128b24334d48ee7664d0579cfe9d59ad Binary files /dev/null and b/mmpretrain/models/retrievers/__pycache__/base.cpython-38.pyc differ diff --git a/mmpretrain/models/retrievers/__pycache__/image2image.cpython-38.pyc b/mmpretrain/models/retrievers/__pycache__/image2image.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b578e815f7765380bc2d12477e6b2f06e59c0294 Binary files /dev/null and b/mmpretrain/models/retrievers/__pycache__/image2image.cpython-38.pyc differ diff --git a/mmpretrain/models/retrievers/base.py b/mmpretrain/models/retrievers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..15816798f3fadc612b51634994178eb5f8860fb8 --- /dev/null +++ b/mmpretrain/models/retrievers/base.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement +from torch.utils.data import DataLoader + + +class BaseRetriever(BaseModel, metaclass=ABCMeta): + """Base class for retriever. + + Args: + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None, it will use "BaseDataPreprocessor" as type, see + :class:`mmengine.model.BaseDataPreprocessor` for more details. + Defaults to None. + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + Attributes: + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An + extra data pre-processing module, which processes data from + dataloader to the format accepted by :meth:`forward`. + """ + + def __init__( + self, + prototype: Union[DataLoader, dict, str, torch.Tensor] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + ): + super(BaseRetriever, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + self.prototype = prototype + self.prototype_inited = False + + @abstractmethod + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'loss'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor without any + post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor, tuple): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor. + - If ``mode="predict"``, return a list of + :obj:`mmpretrain.structures.DataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + pass + + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The sub-classes are recommended to implement this method to extract + features from backbone and neck. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + """ + raise NotImplementedError + + def loss(self, inputs: torch.Tensor, + data_samples: List[BaseDataElement]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + raise NotImplementedError + + def predict(self, + inputs: tuple, + data_samples: Optional[List[BaseDataElement]] = None, + **kwargs) -> List[BaseDataElement]: + """Predict results from the extracted features. + + Args: + inputs (tuple): The features extracted from the backbone. + data_samples (List[BaseDataElement], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + """ + raise NotImplementedError + + def matching(self, inputs: torch.Tensor): + """Compare the prototype and calculate the similarity. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C). + """ + raise NotImplementedError + + def prepare_prototype(self): + """Preprocessing the prototype before predict.""" + raise NotImplementedError + + def dump_prototype(self, path): + """Save the features extracted from the prototype to the specific path. + + Args: + path (str): Path to save feature. + """ + raise NotImplementedError diff --git a/mmpretrain/models/retrievers/image2image.py b/mmpretrain/models/retrievers/image2image.py new file mode 100644 index 0000000000000000000000000000000000000000..a00c1dceb102ee692c44090b62dcfa19dc441f3b --- /dev/null +++ b/mmpretrain/models/retrievers/image2image.py @@ -0,0 +1,314 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, List, Optional, Union + +import mmengine.dist as dist +import torch +import torch.nn as nn +from mmengine.runner import Runner +from torch.utils.data import DataLoader + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .base import BaseRetriever + + +@MODELS.register_module() +class ImageToImageRetriever(BaseRetriever): + """Image To Image Retriever for supervised retrieval task. + + Args: + image_encoder (Union[dict, List[dict]]): Encoder for extracting + features. + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + head (dict, optional): The head module to calculate loss from + processed features. See :mod:`mmpretrain.models.heads`. Notice + that if the head is not set, `loss` method cannot be used. + Defaults to None. + similarity_fn (Union[str, Callable]): The way that the similarity + is calculated. If `similarity` is callable, it is used directly + as the measure function. If it is a string, the appropriate + method will be used. The larger the calculated value, the + greater the similarity. Defaults to "cosine_similarity". + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in + :mod:`mmpretrain.model.utils.augment`. + + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + topk (int): Return the topk of the retrieval result. `-1` means + return all. Defaults to -1. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + image_encoder: Union[dict, List[dict]], + prototype: Union[DataLoader, dict, str, torch.Tensor], + head: Optional[dict] = None, + pretrained: Optional[str] = None, + similarity_fn: Union[str, Callable] = 'cosine_similarity', + train_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None, + topk: int = -1, + init_cfg: Optional[dict] = None): + + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. + data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super(ImageToImageRetriever, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(image_encoder, nn.Module): + image_encoder = MODELS.build(image_encoder) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + + self.image_encoder = image_encoder + self.head = head + + self.similarity = similarity_fn + + assert isinstance(prototype, (str, torch.Tensor, dict, DataLoader)), ( + 'The `prototype` in `ImageToImageRetriever` must be a path, ' + 'a torch.Tensor, a dataloader or a dataloader dict format config.') + self.prototype = prototype + self.prototype_inited = False + self.topk = topk + + @property + def similarity_fn(self): + """Returns a function that calculates the similarity.""" + # If self.similarity_way is callable, return it directly + if isinstance(self.similarity, Callable): + return self.similarity + + if self.similarity == 'cosine_similarity': + # a is a tensor with shape (N, C) + # b is a tensor with shape (M, C) + # "cosine_similarity" will get the matrix of similarity + # with shape (N, M). + # The higher the score is, the more similar is + return lambda a, b: torch.cosine_similarity( + a.unsqueeze(1), b.unsqueeze(0), dim=-1) + else: + raise RuntimeError(f'Invalid function "{self.similarity_fn}".') + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor without any + post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor, tuple): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor. + - If ``mode="predict"``, return a list of + :obj:`mmpretrain.structures.DataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + return self.extract_feat(inputs) + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs): + """Extract features from the input tensor with shape (N, C, ...). + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + Returns: + Tensor: The output of encoder. + """ + + feat = self.image_encoder(inputs) + return feat + + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + feats = self.extract_feat(inputs) + return self.head.loss(feats, data_samples) + + def matching(self, inputs: torch.Tensor): + """Compare the prototype and calculate the similarity. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C). + Returns: + dict: a dictionary of score and prediction label based on fn. + """ + sim = self.similarity_fn(inputs, self.prototype_vecs) + sorted_sim, indices = torch.sort(sim, descending=True, dim=-1) + predictions = dict( + score=sim, pred_label=indices, pred_score=sorted_sim) + return predictions + + def predict(self, + inputs: tuple, + data_samples: Optional[List[DataSample]] = None, + **kwargs) -> List[DataSample]: + """Predict results from the extracted features. + + Args: + inputs (tuple): The features extracted from the backbone. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + Returns: + List[DataSample]: the raw data_samples with + the predicted results + """ + if not self.prototype_inited: + self.prepare_prototype() + + feats = self.extract_feat(inputs) + if isinstance(feats, tuple): + feats = feats[-1] + + # Matching of similarity + result = self.matching(feats) + return self._get_predictions(result, data_samples) + + def _get_predictions(self, result, data_samples): + """Post-process the output of retriever.""" + pred_scores = result['score'] + pred_labels = result['pred_label'] + if self.topk != -1: + topk = min(self.topk, pred_scores.size()[-1]) + pred_labels = pred_labels[:, :topk] + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + return data_samples + + def _get_prototype_vecs_from_dataloader(self, data_loader): + """get prototype_vecs from dataloader.""" + self.eval() + num = len(data_loader.dataset) + + prototype_vecs = None + for data_batch in track_on_main_process(data_loader, + 'Prepare prototype'): + data = self.data_preprocessor(data_batch, False) + feat = self(**data) + if isinstance(feat, tuple): + feat = feat[-1] + + if prototype_vecs is None: + dim = feat.shape[-1] + prototype_vecs = torch.zeros(num, dim) + for i, data_sample in enumerate(data_batch['data_samples']): + sample_idx = data_sample.get('sample_idx') + prototype_vecs[sample_idx] = feat[i] + + assert prototype_vecs is not None + dist.all_reduce(prototype_vecs) + return prototype_vecs + + def _get_prototype_vecs_from_path(self, proto_path): + """get prototype_vecs from prototype path.""" + data = [None] + if dist.is_main_process(): + data[0] = torch.load(proto_path) + dist.broadcast_object_list(data, src=0) + prototype_vecs = data[0] + assert prototype_vecs is not None + return prototype_vecs + + @torch.no_grad() + def prepare_prototype(self): + """Used in meta testing. This function will be called before the meta + testing. Obtain the vector based on the prototype. + + - torch.Tensor: The prototype vector is the prototype + - str: The path of the extracted feature path, parse data structure, + and generate the prototype feature vector set + - Dataloader or config: Extract and save the feature vectors according + to the dataloader + """ + device = next(self.image_encoder.parameters()).device + if isinstance(self.prototype, torch.Tensor): + prototype_vecs = self.prototype + elif isinstance(self.prototype, str): + prototype_vecs = self._get_prototype_vecs_from_path(self.prototype) + elif isinstance(self.prototype, (dict, DataLoader)): + loader = Runner.build_dataloader(self.prototype) + prototype_vecs = self._get_prototype_vecs_from_dataloader(loader) + + self.register_buffer( + 'prototype_vecs', prototype_vecs.to(device), persistent=False) + self.prototype_inited = True + + def dump_prototype(self, path): + """Save the features extracted from the prototype to specific path. + + Args: + path (str): Path to save feature. + """ + if not self.prototype_inited: + self.prepare_prototype() + torch.save(self.prototype_vecs, path) diff --git a/mmpretrain/models/selfsup/__init__.py b/mmpretrain/models/selfsup/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1052dedc4af6476b43c72314f0ee53f9ba66be50 --- /dev/null +++ b/mmpretrain/models/selfsup/__init__.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .barlowtwins import BarlowTwins +from .base import BaseSelfSupervisor +from .beit import VQKD, BEiT, BEiTPretrainViT +from .byol import BYOL +from .cae import CAE, CAEPretrainViT, DALLEEncoder +from .densecl import DenseCL +from .eva import EVA +from .itpn import iTPN, iTPNHiViT +from .mae import MAE, MAEHiViT, MAEViT +from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT +from .milan import MILAN, CLIPGenerator, MILANViT +from .mixmim import MixMIM, MixMIMPretrainTransformer +from .moco import MoCo +from .mocov3 import MoCoV3, MoCoV3ViT +from .simclr import SimCLR +from .simmim import SimMIM, SimMIMSwinTransformer +from .simsiam import SimSiam +from .spark import SparK +from .swav import SwAV + +__all__ = [ + 'BaseSelfSupervisor', + 'BEiTPretrainViT', + 'VQKD', + 'CAEPretrainViT', + 'DALLEEncoder', + 'MAEViT', + 'MAEHiViT', + 'iTPNHiViT', + 'iTPN', + 'HOGGenerator', + 'MaskFeatViT', + 'CLIPGenerator', + 'MILANViT', + 'MixMIMPretrainTransformer', + 'MoCoV3ViT', + 'SimMIMSwinTransformer', + 'MoCo', + 'MoCoV3', + 'BYOL', + 'SimCLR', + 'SimSiam', + 'BEiT', + 'CAE', + 'MAE', + 'MaskFeat', + 'MILAN', + 'MixMIM', + 'SimMIM', + 'EVA', + 'DenseCL', + 'BarlowTwins', + 'SwAV', + 'SparK', +] diff --git a/mmpretrain/models/selfsup/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..949a8059015affd2f564e8fccb335a7c38fd3e93 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/barlowtwins.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/barlowtwins.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5d1e4bef6b16603b20b2907e0913051d6d70835 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/barlowtwins.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/base.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eebddeaa0eb657a45804acc942e4b56eb59550a8 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/base.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/beit.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/beit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbab301a566db098464df4965155672a51cc5d21 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/beit.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/byol.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/byol.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72415bccc098c27364409b96e6a9b6c56a39a0e0 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/byol.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/cae.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/cae.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c10c57312cac3d7252c1a106490eb511db366c0 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/cae.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/densecl.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/densecl.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f3652cb67f9c305610896f1be94c5397675a642 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/densecl.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/eva.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/eva.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1e6455b630c864e04977e3101ac461e6c23afa7 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/eva.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/itpn.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/itpn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05af48910e7e62fad7b0390654d454809b2e7de6 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/itpn.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/mae.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/mae.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d781744327f49f4d032d7b3ceff2e555836161e0 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/mae.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/maskfeat.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/maskfeat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d42d1239e0ecf8744dfe66ff3ba0281df4c73e8 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/maskfeat.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/milan.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/milan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..160faf660d88e37f3563f157a2da36009f71d5de Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/milan.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/mixmim.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/mixmim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03d9a1cfb2720e41093624d0b6e0a5e8d6286e04 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/mixmim.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/moco.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/moco.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a218589f807be4efad45ccdd3092200f2afac96a Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/moco.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/mocov3.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/mocov3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9083cd6edc7ab497de61e29d0d04301aff053a9c Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/mocov3.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/simclr.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/simclr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ff3f489462ba0e19b38e52e0f38a9ee68b27dfa Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/simclr.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/simmim.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/simmim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3748d1bca694efc40509fca056d4de8dbecfd021 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/simmim.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/simsiam.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/simsiam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..704d395b00bb12b2eb1d2b379ed520e68ccd335c Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/simsiam.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/spark.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/spark.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d13eacd8a624be5fbf7c7de1ff909bae6350f3ce Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/spark.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/swav.cpython-38.pyc b/mmpretrain/models/selfsup/__pycache__/swav.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d80ff89dfcaa1f50d5591719bf032db0a19a7471 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/swav.cpython-38.pyc differ diff --git a/mmpretrain/models/selfsup/barlowtwins.py b/mmpretrain/models/selfsup/barlowtwins.py new file mode 100644 index 0000000000000000000000000000000000000000..4c75cd0caca6ab2dc4c4a14e365fda5daa9bdb83 --- /dev/null +++ b/mmpretrain/models/selfsup/barlowtwins.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class BarlowTwins(BaseSelfSupervisor): + """BarlowTwins. + + Implementation of `Barlow Twins: Self-Supervised Learning via Redundancy + Reduction `_. + Part of the code is borrowed from: + ``_. + """ + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + img_v1 = inputs[0] + img_v2 = inputs[1] + + z1 = self.neck(self.backbone(img_v1))[0] # NxC + z2 = self.neck(self.backbone(img_v2))[0] # NxC + + loss = self.head.loss(z1, z2) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/base.py b/mmpretrain/models/selfsup/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9d53a72871dff7b4fc59cd591686350026a875bb --- /dev/null +++ b/mmpretrain/models/selfsup/base.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta): + """BaseModel for Self-Supervised Learning. + + All self-supervised algorithms should inherit this module. + + Args: + backbone (dict): The backbone module. See + :mod:`mmpretrain.models.backbones`. + neck (dict, optional): The neck module to process features from + backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. + head (dict, optional): The head module to do prediction and calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, almost all methods cannot be + used except :meth:`extract_feat`. Defaults to None. + target_generator: (dict, optional): The target_generator module to + generate targets for self-supervised learning optimization, such as + HOG, extracted features from other modules(DALL-E, CLIP), etc. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + target_generator: Optional[dict] = None, + pretrained: Optional[str] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + if pretrained is not None: + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + data_preprocessor = data_preprocessor or {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'SelfSupDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + elif not isinstance(data_preprocessor, nn.Module): + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(backbone, nn.Module): + backbone = MODELS.build(backbone) + if neck is not None and not isinstance(neck, nn.Module): + neck = MODELS.build(neck) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + if target_generator is not None and not isinstance( + target_generator, nn.Module): + target_generator = MODELS.build(target_generator) + + self.backbone = backbone + self.neck = neck + self.head = head + self.target_generator = target_generator + + @property + def with_neck(self) -> bool: + """Check if the model has a neck module.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_head(self) -> bool: + """Check if the model has a head module.""" + return hasattr(self, 'head') and self.head is not None + + @property + def with_target_generator(self) -> bool: + """Check if the model has a target_generator module.""" + return hasattr( + self, 'target_generator') and self.target_generator is not None + + def forward(self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method currently accepts two modes: "tensor" and "loss": + + - "tensor": Forward the backbone network and return the feature + tensor(s) tensor without any post-processing, same as a common + PyTorch Module. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Args: + inputs (torch.Tensor or List[torch.Tensor]): The input tensor with + shape (N, C, ...) in general. + data_samples (List[DataSample], optional): The other data of + every samples. It's required for some algorithms + if ``mode="loss"``. Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + feats = self.extract_feat(inputs) + return feats + elif mode == 'loss': + return self.loss(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The default behavior is extracting features from backbone. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + + Returns: + tuple | Tensor: The output feature tensor(s). + """ + x = self.backbone(inputs) + return x + + @abstractmethod + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + This is a abstract method, and subclass should overwrite this methods + if needed. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + raise NotImplementedError + + def get_layer_depth(self, param_name: str): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + + Returns: + Tuple[int, int]: The layer-wise depth and the max depth. + """ + if hasattr(self.backbone, 'get_layer_depth'): + return self.backbone.get_layer_depth(param_name, 'backbone.') + else: + raise NotImplementedError( + f"The backbone {type(self.backbone)} doesn't " + 'support `get_layer_depth` by now.') diff --git a/mmpretrain/models/selfsup/beit.py b/mmpretrain/models/selfsup/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..c301f7d5cae07370f26b4cd531190b8c3c90e24b --- /dev/null +++ b/mmpretrain/models/selfsup/beit.py @@ -0,0 +1,357 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +from einops import rearrange +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ +from torch import nn + +from mmpretrain.models.backbones import BEiTViT +from mmpretrain.models.utils import NormEMAVectorQuantizer, resize_pos_embed +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class VQKD(BaseModule): + """Vector-Quantized Knowledge Distillation. + + The module only contains encoder and VectorQuantizer part + Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py + + Args: + encoder_config (dict): The config of encoder. + decoder_config (dict, optional): The config of decoder. Currently, + VQKD only support to build encoder. Defaults to None. + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + decay (float): The decay parameter of EMA. Defaults to 0.99. + beta (float): The mutiplier for VectorQuantizer loss. Defaults to 1. + quantize_kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ # noqa: E501 + + def __init__(self, + encoder_config: dict, + decoder_config: Optional[dict] = None, + num_embed: int = 8192, + embed_dims: int = 32, + decay: float = 0.99, + beta: float = 1.0, + quantize_kmeans_init: bool = True, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.encoder = BEiTViT(**encoder_config) + if decoder_config is not None: + self.decoder = BEiTViT(**decoder_config) + + self.quantize = NormEMAVectorQuantizer( + num_embed=num_embed, + embed_dims=embed_dims, + beta=beta, + decay=decay, + kmeans_init=quantize_kmeans_init, + ) + + # task layer + self.encode_task_layer = nn.Sequential( + nn.Linear(self.encoder.arch_settings['embed_dims'], + self.encoder.arch_settings['embed_dims']), nn.Tanh(), + nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims)) + + def get_tokens(self, x: torch.Tensor) -> dict: + """Get tokens for beit pre-training.""" + _, embed_ind, _ = self.encode(x) + output = {} + output['token'] = embed_ind.view(x.shape[0], -1) + output['input_img'] = x + + return output + + def encode( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode the input images and get corresponding results.""" + encoder_features = self.encoder(x)[0] + B, C, N1, N2 = encoder_features.shape + encoder_features = encoder_features.permute(0, 2, 3, + 1).reshape(B, N1 * N2, C) + + with torch.cuda.amp.autocast(enabled=False): + to_quantizer_features = self.encode_task_layer( + encoder_features.type_as(self.encode_task_layer[-1].weight)) + + N = to_quantizer_features.shape[1] + h, w = int(math.sqrt(N)), int(math.sqrt(N)) + + to_quantizer_features = rearrange( + to_quantizer_features, 'b (h w) c -> b c h w', h=h, + w=w) # reshape for quantizer + quantize, loss, embed_ind = self.quantize(to_quantizer_features) + + return quantize, embed_ind, loss + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The forward function. + + Currently, only support to get tokens. + """ + return self.get_tokens(x)['token'] + + +@MODELS.register_module() +class BEiTPretrainViT(BEiTViT): + """Vision Transformer for BEiT pre-training. + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + use_abs_pos_emb (bool): Whether or not use absolute position embedding. + Defaults to False. + use_rel_pos_bias (bool): Whether or not use relative position bias. + Defaults to False. + use_shared_rel_pos_bias (bool): Whether or not use shared relative + position bias. Defaults to True. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. Defaults to 0.1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: str = 'base', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + frozen_stages: int = -1, + use_abs_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + layer_scale_init_value: int = 0.1, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(padding=0), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + frozen_stages=frozen_stages, + use_abs_pos_emb=use_abs_pos_emb, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + use_rel_pos_bias=use_rel_pos_bias, + layer_scale_init_value=layer_scale_init_value, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + trunc_normal_(self.cls_token, std=0.02) + trunc_normal_(self.mask_token, std=0.02) + self.rescale_init_weight() + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor]: + """The BEiT style forward function. + + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images, which is of shape (B x C x H x W). + mask (torch.Tensor, optional): Mask for input, which is of shape + (B x patch_resolution[0] x patch_resolution[1]). + + Returns: + Tuple[torch.Tensor]: Hidden features. + """ + if mask is None: + return super().forward(x) + + else: + x, patch_resolution = self.patch_embed(x) + + # replace the masked visual tokens by mask_token + B, L, _ = x.shape + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + self.shared_rel_pos_bias = self.rel_pos_bias().to( + mask.device) if self.rel_pos_bias is not None else None + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, rel_pos_bias=self.shared_rel_pos_bias) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + +@MODELS.register_module() +class BEiT(BaseSelfSupervisor): + """BEiT v1/v2. + + Implementation of `BEiT: BERT Pre-Training of Image Transformers + `_ and `BEiT v2: Masked Image Modeling + with Vector-Quantized Visual Tokenizers + `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + + img_latent = self.backbone(inputs[0], mask) + + # inputs[1] is the target image + with torch.no_grad(): + target = self.target_generator(inputs[1]) + target = target.detach() + + if self.with_neck: + # BEiT v2 + feats, feats_cls_pt = self.neck( + img_latent, rel_pos_bias=self.backbone.shared_rel_pos_bias) + loss = self.head.loss(feats, feats_cls_pt, target, mask) + else: + # BEiT v1 + loss = self.head.loss(img_latent[0], target, mask) + + if isinstance(loss, torch.Tensor): + losses = dict(loss=loss) + return losses + elif isinstance(loss, Tuple): + # the loss_1 and loss_2 are general reconstruction loss (patch + # feature vectors from last layer of backbone) and early state + # reconstruction loss (patch feature vectors from intermediate + # layer of backbone) + loss_1, loss_2 = loss[0], loss[1] + losses = dict() + # the key with prefix 'loss', like loss_1 and loss_2, will be used + # as the final criterion + losses['loss_1'] = loss_1 + losses['loss_2'] = loss_2 + return losses diff --git a/mmpretrain/models/selfsup/byol.py b/mmpretrain/models/selfsup/byol.py new file mode 100644 index 0000000000000000000000000000000000000000..803e4005da8620b0e5a93fb29cb65e90a78f345f --- /dev/null +++ b/mmpretrain/models/selfsup/byol.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import CosineEMA +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class BYOL(BaseSelfSupervisor): + """BYOL. + + Implementation of `Bootstrap Your Own Latent: A New Approach to + Self-Supervised Learning `_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features + to compact feature vectors. + head (dict): Config dict for module of head functions. + base_momentum (float): The base momentum coefficient for the target + network. Defaults to 0.004. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + base_momentum: float = 0.004, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.target_net = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + img_v1 = inputs[0] + img_v2 = inputs[1] + # compute online features + proj_online_v1 = self.neck(self.backbone(img_v1))[0] + proj_online_v2 = self.neck(self.backbone(img_v2))[0] + # compute target features + with torch.no_grad(): + # update the target net + self.target_net.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + proj_target_v1 = self.target_net(img_v1)[0] + proj_target_v2 = self.target_net(img_v2)[0] + + loss_1 = self.head.loss(proj_online_v1, proj_target_v2) + loss_2 = self.head.loss(proj_online_v2, proj_target_v1) + + losses = dict(loss=2. * (loss_1 + loss_2)) + return losses diff --git a/mmpretrain/models/selfsup/cae.py b/mmpretrain/models/selfsup/cae.py new file mode 100644 index 0000000000000000000000000000000000000000..67ac09188e9bf97cdbea63378aa4facb1e8348ab --- /dev/null +++ b/mmpretrain/models/selfsup/cae.py @@ -0,0 +1,472 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Part of code is modified from BEiT +# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py +import math +from collections import OrderedDict +from functools import partial +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones import BEiTViT +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +class Conv2d(nn.Module): + """Rewrite Conv2d module according to DALL-E code.""" + + def __init__(self, + n_in: int, + n_out: int, + kw: int, + use_float16: bool = True, + device: torch.device = torch.device('cpu'), + requires_grad: bool = False) -> None: + super().__init__() + + w = torch.empty((n_out, n_in, kw, kw), + dtype=torch.float32, + device=device, + requires_grad=requires_grad) + w.normal_(std=1 / math.sqrt(n_in * kw**2)) + + b = torch.zeros((n_out, ), + dtype=torch.float32, + device=device, + requires_grad=requires_grad) + self.kw = kw + self.w, self.b = nn.Parameter(w), nn.Parameter(b) + self.use_float16 = use_float16 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_float16 and 'cuda' in self.w.device.type: + if x.dtype != torch.float16: + x = x.half() + + w, b = self.w.half(), self.b.half() + else: + if x.dtype != torch.float32: + x = x.float() + + w, b = self.w, self.b + + return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) + + +class EncoderBlock(nn.Module): + """Rewrite EncoderBlock module according to DALL-E code.""" + + def __init__(self, + n_in: int, + n_out: int, + n_layers: int, + device: torch.device = None, + requires_grad: bool = False) -> None: + super().__init__() + self.n_hid = n_out // 4 + self.post_gain = 1 / (n_layers**2) + + make_conv = partial(Conv2d, device=device, requires_grad=requires_grad) + self.id_path = make_conv(n_in, n_out, + 1) if n_in != n_out else nn.Identity() + self.res_path = nn.Sequential( + OrderedDict([ + ('relu_1', nn.ReLU()), + ('conv_1', make_conv(n_in, self.n_hid, 3)), + ('relu_2', nn.ReLU()), + ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_3', nn.ReLU()), + ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_4', nn.ReLU()), + ('conv_4', make_conv(self.n_hid, n_out, 1)), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.id_path(x) + self.post_gain * self.res_path(x) + + +@MODELS.register_module(name='DALL-E') +class DALLEEncoder(BaseModule): + """DALL-E Encoder for feature extraction. + + Args: + group_count (int): Number of groups in DALL-E encoder. Defaults to 4. + n_hid (int): Dimension of hidden layers. Defaults to 256. + n_blk_per_group (int): Number of blocks per group. Defaults to 2. + input_channels: (int): The channels of input images. Defaults to 3. + vocab_size (int): Vocabulary size, indicating the number of classes. + Defaults to 8192. + device (torch.device): Device of parameters. Defaults to + ``torch.device('cpu')``. + requires_grad (bool): Require gradient or not. Defaults to False. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + group_count: int = 4, + n_hid: int = 256, + n_blk_per_group: int = 2, + input_channels: int = 3, + vocab_size: int = 8192, + device: torch.device = torch.device('cpu'), + requires_grad: bool = False, + init_cfg: Union[dict, List[dict], None] = None): + super().__init__(init_cfg=init_cfg) + self.input_channels = input_channels + + blk_range = range(n_blk_per_group) + n_layers = group_count * n_blk_per_group + make_conv = partial(Conv2d, device=device, requires_grad=requires_grad) + make_blk = partial( + EncoderBlock, + n_layers=n_layers, + device=device, + requires_grad=requires_grad) + + self.blocks = nn.Sequential( + OrderedDict([ + ('input', make_conv(input_channels, 1 * n_hid, 7)), + ('group_1', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', make_blk(1 * n_hid, 1 * n_hid)) + for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_2', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(1 * n_hid if i == 0 else 2 * n_hid, + 2 * n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_3', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(2 * n_hid if i == 0 else 4 * n_hid, + 4 * n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_4', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(4 * n_hid if i == 0 else 8 * n_hid, + 8 * n_hid)) for i in blk_range], + ]))), + ('output', + nn.Sequential( + OrderedDict([ + ('relu', nn.ReLU()), + ('conv', + make_conv( + 8 * n_hid, vocab_size, 1, use_float16=False)), + ]))), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function of DALL-E encoder. + + Args: + x (torch.Tensor): The input images with shape (B, C, H, W). + + Returns: + torch.Tensor: The output with shape (B, vocab_size, h, w). + """ + x = x.float() + if len(x.shape) != 4: + raise ValueError(f'input shape {x.shape} is not 4d') + if x.shape[1] != self.input_channels: + raise ValueError(f'input has {x.shape[1]} channels but model \ + built for {self.input_channels}') + if x.dtype != torch.float32: + raise ValueError('input must have dtype torch.float32') + + return self.blocks(x) + + +@MODELS.register_module() +class CAEPretrainViT(BEiTViT): + """Vision Transformer for CAE pre-training and the implementation is based + on BEiTViT. + + Args: + arch (str | dict): Vision Transformer architecture. Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + bias (bool | str): The option to add leanable bias for q, k, v. If bias + is True, it will add leanable bias. If bias is 'qv_bias', it will + only add leanable bias for q, v. If bias is False, it will not add + bias for q, k, v. Default to 'qv_bias'. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + layer_scale_init_value (float, optional): The init value of gamma in + BEiTTransformerEncoderLayer. + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + arch: str = 'b', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + bias: bool = 'qv_bias', + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + frozen_stages: int = -1, + use_abs_pos_emb: bool = True, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = False, + layer_scale_init_value: float = None, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: dict = [ + dict(type='Constant', val=1, layer=['LayerNorm']), + dict(type='TruncNormal', std=0.02, layer=['Conv2d']), + dict(type='Xavier', distribution='uniform', layer=['Linear']) + ] + ) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + bias=bias, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + frozen_stages=frozen_stages, + use_abs_pos_emb=use_abs_pos_emb, + use_rel_pos_bias=use_rel_pos_bias, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + layer_scale_init_value=layer_scale_init_value, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + self.pos_embed.requires_grad = False + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # initialize position embedding in backbone + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=True) + self.pos_embed.data.copy_(pos_embed.float()) + + trunc_normal_(self.cls_token, std=.02) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> torch.Tensor: + """Generate features for masked images. + + This function generates mask images and get the hidden features for + visible patches. + + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (torch.Tensor, optional): Mask for input, which is of shape + B x L. + + Returns: + torch.Tensor: hidden features. + """ + if mask is None: + return super().forward(x) + + else: + x, _ = self.patch_embed(x) + batch_size, _, dim = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + + # NOTE: unmasked embeddings + x_unmasked = x[~mask].reshape(batch_size, -1, dim) + x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1) + + pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1, + dim) + pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape( + batch_size, -1, dim) + pos_embed_unmasked = torch.cat( + (pos_embed[:, :1], pos_embed_unmasked), dim=1) + x_unmasked = x_unmasked + pos_embed_unmasked + + x_unmasked = self.drop_after_pos(x_unmasked) + + for i, layer in enumerate(self.layers): + x_unmasked = layer(x=x_unmasked, rel_pos_bias=None) + + if i == len(self.layers) - 1 and self.final_norm: + x_unmasked = self.norm1(x_unmasked) + + return x_unmasked + + +@MODELS.register_module() +class CAE(BaseSelfSupervisor): + """CAE. + + Implementation of `Context Autoencoder for Self-Supervised Representation + Learning `_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of neck. + head (dict): Config dict for module of head functions. + target_generator: (dict, optional): The target_generator module to + generate targets for self-supervised learning optimization, such as + HOG, extracted features from other modules(DALL-E, CLIP), etc. + base_momentum (float): The base momentum coefficient for the target + network. Defaults to 0.0. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + target_generator: Optional[dict] = None, + base_momentum: float = 0.0, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + target_generator=target_generator, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + self.momentum = base_momentum + self.teacher = MODELS.build(backbone) + + def init_weights(self) -> None: + """Initialize weights.""" + super().init_weights() + + # init the weights of teacher with those of backbone + for param_backbone, param_teacher in zip(self.backbone.parameters(), + self.teacher.parameters()): + param_teacher.detach() + param_teacher.data.copy_(param_backbone.data) + param_teacher.requires_grad = False + + def momentum_update(self) -> None: + """Momentum update of the teacher network.""" + for param_bacbone, param_teacher in zip(self.backbone.parameters(), + self.teacher.parameters()): + param_teacher.data = param_teacher.data * self.momentum + \ + param_bacbone.data * (1. - self.momentum) + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + mask = mask.flatten(1).to(torch.bool) + + unmasked = self.backbone(inputs[0], mask) + + # get the latent prediction for the masked patches + with torch.no_grad(): + # inputs[0] is the prediction image + latent_target = self.teacher(inputs[0], ~mask) + latent_target = latent_target[:, 1:, :] + self.momentum_update() + + pos_embed = self.backbone.pos_embed.expand(inputs[0].shape[0], -1, -1) + pos_embed_masked = pos_embed[:, + 1:][mask].reshape(inputs[0].shape[0], -1, + pos_embed.shape[-1]) + pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape( + inputs[0].shape[0], -1, pos_embed.shape[-1]) + + # input the unmasked tokens and masked tokens to the decoder + logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked, + pos_embed_unmasked) + + logits = logits.view(-1, logits.shape[-1]) + # inputs[1] is the target image + logits_target = self.target_generator(inputs[1]) + loss_main, loss_align = self.head.loss(logits, logits_target, + latent_pred, latent_target, + mask) + losses = dict() + + losses['loss'] = loss_main + loss_align + losses['main'] = loss_main + losses['align'] = loss_align + return losses diff --git a/mmpretrain/models/selfsup/densecl.py b/mmpretrain/models/selfsup/densecl.py new file mode 100644 index 0000000000000000000000000000000000000000..c969af17fa921a119f6b05b5a319e104f6422494 --- /dev/null +++ b/mmpretrain/models/selfsup/densecl.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_gather +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class DenseCL(BaseSelfSupervisor): + """DenseCL. + + Implementation of `Dense Contrastive Learning for Self-Supervised Visual + Pre-Training `_. + Borrowed from the authors' code: ``_. + The loss_lambda warmup is in `engine/hooks/densecl_hook.py`. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features to compact + feature vectors. + head (dict): Config dict for module of head functions. + queue_len (int): Number of negative keys maintained in the queue. + Defaults to 65536. + feat_dim (int): Dimension of compact feature vectors. Defaults to 128. + momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.999. + loss_lambda (float): Loss weight for the single and dense contrastive + loss. Defaults to 0.5. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + queue_len: int = 65536, + feat_dim: int = 128, + momentum: float = 0.001, + loss_lambda: float = 0.5, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.encoder_k = ExponentialMovingAverage( + nn.Sequential(self.backbone, self.neck), momentum) + + self.queue_len = queue_len + self.loss_lambda = loss_lambda + + # create the queue + self.register_buffer('queue', torch.randn(feat_dim, queue_len)) + self.queue = nn.functional.normalize(self.queue, dim=0) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + # create the second queue for dense output + self.register_buffer('queue2', torch.randn(feat_dim, queue_len)) + self.queue2 = nn.functional.normalize(self.queue2, dim=0) + self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None: + """Update queue.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue_ptr[0] = ptr + + @torch.no_grad() + def _dequeue_and_enqueue2(self, keys: torch.Tensor) -> None: + """Update queue2.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue2_ptr) + assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue2[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue2_ptr[0] = ptr + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + im_q = inputs[0] + im_k = inputs[1] + # compute query features + q_b = self.backbone(im_q) # backbone features + q, q_grid, q2 = self.neck(q_b) # queries: NxC; NxCxS^2 + q_b = q_b[0] + q_b = q_b.view(q_b.size(0), q_b.size(1), -1) + + q = nn.functional.normalize(q, dim=1) + q2 = nn.functional.normalize(q2, dim=1) + q_grid = nn.functional.normalize(q_grid, dim=1) + q_b = nn.functional.normalize(q_b, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + # update the key encoder + self.encoder_k.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + # shuffle for making use of BN + im_k, idx_unshuffle = batch_shuffle_ddp(im_k) + + k_b = self.encoder_k.module[0](im_k) # backbone features + k, k_grid, k2 = self.encoder_k.module[1](k_b) # keys: NxC; NxCxS^2 + k_b = k_b[0] + k_b = k_b.view(k_b.size(0), k_b.size(1), -1) + + k = nn.functional.normalize(k, dim=1) + k2 = nn.functional.normalize(k2, dim=1) + k_grid = nn.functional.normalize(k_grid, dim=1) + k_b = nn.functional.normalize(k_b, dim=1) + + # undo shuffle + k = batch_unshuffle_ddp(k, idx_unshuffle) + k2 = batch_unshuffle_ddp(k2, idx_unshuffle) + k_grid = batch_unshuffle_ddp(k_grid, idx_unshuffle) + k_b = batch_unshuffle_ddp(k_b, idx_unshuffle) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + # feat point set sim + backbone_sim_matrix = torch.matmul(q_b.permute(0, 2, 1), k_b) + densecl_sim_ind = backbone_sim_matrix.max(dim=2)[1] # NxS^2 + + indexed_k_grid = torch.gather(k_grid, 2, + densecl_sim_ind.unsqueeze(1).expand( + -1, k_grid.size(1), -1)) # NxCxS^2 + densecl_sim_q = (q_grid * indexed_k_grid).sum(1) # NxS^2 + + # dense positive logits: NS^2X1 + l_pos_dense = densecl_sim_q.view(-1).unsqueeze(-1) + + q_grid = q_grid.permute(0, 2, 1) + q_grid = q_grid.reshape(-1, q_grid.size(2)) + # dense negative logits: NS^2xK + l_neg_dense = torch.einsum( + 'nc,ck->nk', [q_grid, self.queue2.clone().detach()]) + + loss_single = self.head.loss(l_pos, l_neg) + loss_dense = self.head.loss(l_pos_dense, l_neg_dense) + + losses = dict() + losses['loss_single'] = loss_single * (1 - self.loss_lambda) + losses['loss_dense'] = loss_dense * self.loss_lambda + + self._dequeue_and_enqueue(k) + self._dequeue_and_enqueue2(k2) + + return losses diff --git a/mmpretrain/models/selfsup/eva.py b/mmpretrain/models/selfsup/eva.py new file mode 100644 index 0000000000000000000000000000000000000000..30779bec491ae7c95b6540cdc7d71a875da572de --- /dev/null +++ b/mmpretrain/models/selfsup/eva.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class EVA(BaseSelfSupervisor): + """EVA. + + Implementation of `EVA: Exploring the Limits of Masked Visual + Representation Learning at Scale `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + clip_feature, _ = self.target_generator(inputs) + + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + + clip_feature = clip_feature[:, 1:, :] + loss = self.head.loss(pred, clip_feature, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py new file mode 100644 index 0000000000000000000000000000000000000000..85efd254053156c450b191d2d01a208882e874d9 --- /dev/null +++ b/mmpretrain/models/selfsup/itpn.py @@ -0,0 +1,356 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones.hivit import BlockWithRPE, HiViT, PatchMerge +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class iTPNHiViT(HiViT): + """HiViT for iTPN pre-training. + + Args: + img_size (int | tuple): Input image size. Defaults to 224. + patch_size (int | tuple): The patch size. Defaults to 16. + inner_patches (int): Inner patch. Defaults to 4. + stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim + in the first two stages. Defaults to 3. + mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in + the last stage. Defaults to 4. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): If True, add absolute position embedding to + the patch embedding. + rpe (bool): If True, add relative position embedding to + the patch embedding. + layer_scale_init_value (float): Layer-scale init values. Defaults to 0. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + reconstruction_type (str): The reconstruction of self-supervised + learning. Defaults to 'pixel'. + """ + + def __init__( + self, + arch='base', + img_size: int = 224, + patch_size: int = 16, + inner_patches: int = 4, + stem_mlp_ratio: int = 3., + mlp_ratio: int = 4., + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + ape: bool = True, + rpe: bool = False, + layer_scale_init_value: float = 0.0, + mask_ratio: float = 0.75, + reconstruction_type: str = 'pixel', + ): + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + stem_mlp_ratio=stem_mlp_ratio, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + ape=ape, + rpe=rpe, + layer_scale_init_value=layer_scale_init_value) + + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + + assert reconstruction_type in ['pixel', 'clip'], \ + 'iTPN method only support `pixel` and `clip`, ' \ + f'but got `{reconstruction_type}`.' + self.reconstruction_type = reconstruction_type + self.num_patches = self.patch_embed.num_patches + + if reconstruction_type == 'clip': + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().apply(self._init_weights) + + if self.reconstruction_type == 'clip': + trunc_normal_(self.mask_token, std=0.02) + self.rescale_init_weight() + else: + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=False) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + if isinstance(layer, BlockWithRPE): + if layer.attn is not None: + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def masking_id(self, batch_size, mask_ratio): + N, L = batch_size, self.pos_embed.size(1) + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand( + N, L, device=self.pos_embed.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=self.pos_embed.device) + mask[:, :ids_keep.size(1)] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return ids_keep, ids_restore, mask + + def forward_pixel( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[Tuple, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B, C, H, W = x.shape + ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio) + + x = self.patch_embed(x) + + x = torch.gather( + x, + dim=1, + index=ids_keep[:, :, None, None, + None].expand(-1, -1, *x.shape[2:])) + + outs = [] + for blk in self.blocks[:-self.num_main_blocks]: + if isinstance(blk, PatchMerge): + outs.append(x) + x = blk(x) + + x = x[..., 0, 0, :] + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + pos_embed = torch.gather( + pos_embed.expand(B, -1, -1), + dim=1, + index=ids_keep[:, :, None].expand(-1, -1, + pos_embed.shape[2]), + ) + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x) + + outs.append(x) + + return (tuple(outs), mask, ids_restore) + + def forward_clip(self, + x: torch.Tensor, + mask: Optional[bool] = True) -> Tuple: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B, C, H, W = x.shape + x = self.patch_embed(x) + + outs = [] + for blk in self.blocks[:-self.num_main_blocks]: + if isinstance(blk, PatchMerge): + outs.append(x) + x = blk(x) + + x = x[..., 0, 0, :] + B, L, _ = x.shape + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + x = x + pos_embed + x = self.pos_drop(x) + + rpe_index = True if self.rpe else None + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x, rpe_index) + + outs.append(x) + + return tuple(outs) + + def forward(self, x: torch.Tensor, mask: Optional[bool] = True) -> Tuple: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + + if self.reconstruction_type == 'pixel': + return self.forward_pixel(x, mask) + return self.forward_clip(x, mask) + + +@MODELS.register_module() +class iTPN(BaseSelfSupervisor): + """iTPN. + + Implementation of `iTPN: Integrally Pre-Trained Transformer Pyramid + Networks `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + if self.backbone.reconstruction_type == 'pixel': + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + + loss = self.head.loss(pred, inputs, mask) + else: + mask = torch.stack( + [data_sample.mask for data_sample in data_samples]) + + img_latent = self.backbone(inputs[0], mask) + + # inputs[1] is the target image + with torch.no_grad(): + target = self.target_generator(inputs[1])[0] + target = target.detach() + + # iTPN contains a neck module + feats = self.neck(img_latent) + loss = self.head.loss(feats, target[:, 1:, :], mask) + + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mae.py b/mmpretrain/models/selfsup/mae.py new file mode 100644 index 0000000000000000000000000000000000000000..01bc5bc5134e02488556eacd8cfc30c2fae44fea --- /dev/null +++ b/mmpretrain/models/selfsup/mae.py @@ -0,0 +1,416 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch + +from mmpretrain.models import HiViT, VisionTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MAEViT(VisionTransformer): + """Vision Transformer for MAE pre-training. + + A PyTorch implement of: `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_. + This module implements the patch masking in MAE and initialize the + position embedding with sine-cosine position embedding. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + # position embedding is not learnable during pretraining + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=True) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.projection.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + torch.nn.init.normal_(self.cls_token, std=.02) + + def random_masking( + self, + x: torch.Tensor, + mask_ratio: float = 0.75 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the mask for MAE Pre-training. + + Args: + x (torch.Tensor): Image with data augmentation applied, which is + of shape B x L x C. + mask_ratio (float): The mask ratio of total patches. + Defaults to 0.75. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: masked image, mask + and the ids to restore original image. + + - ``x_masked`` (torch.Tensor): masked image. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, self.mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for _, layer in enumerate(self.layers): + x = layer(x) + # Use final norm + x = self.norm1(x) + + return (x, mask, ids_restore) + + +@MODELS.register_module() +class MAE(BaseSelfSupervisor): + """MAE. + + Implementation of `Masked Autoencoders Are Scalable Vision Learners + `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + loss = self.head.loss(pred, inputs, mask) + losses = dict(loss=loss) + return losses + + +@MODELS.register_module() +class MAEHiViT(HiViT): + """HiViT for MAE pre-training. + + A PyTorch implement of: `HiViT: A Simple and More Efficient Design + of Hierarchical Vision Transformer `_. + This module implements the patch masking in MAE and initialize the + position embedding with sine-cosine position embedding. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + Defaults to 4, to downsample 4x at the first stage + inner_patches (int): The inner patches within a token + Defaults to 4 + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): the absolute position embedding + rpe (bool): the relative position embedding + Defaults to False + layer_scale_init_value (float): the layer scale init value + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + inner_patches: int = 4, + out_indices: Union[list, int] = [23], + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + ape: bool = True, + rpe: bool = False, + layer_scale_init_value: float = 0.0, + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + ape=ape, + rpe=rpe, + layer_scale_init_value=layer_scale_init_value, + init_cfg=init_cfg) + + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + self.num_patches = self.patch_embed.num_patches + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding.""" + super().apply(self._init_weights) + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=False) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def masking_id( + self, batch_size, + mask_ratio) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the mask for MAE Pre-training. + + Args: + batch_size: The batch size of input data + mask_ratio: The mask ratio of total patches. + Defaults to 0.75. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: the ids + for the tokens retained, the ids to restore original image, + and the mask + """ + N, L = batch_size, self.pos_embed.size(1) + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand( + N, L, device=self.pos_embed.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=self.pos_embed.device) + mask[:, :ids_keep.size(1)] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return ids_keep, ids_restore, mask + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B, C, H, W = x.shape + ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio) + + x = self.patch_embed(x) + + x = torch.gather( + x, + dim=1, + index=ids_keep[:, :, None, None, + None].expand(-1, -1, *x.shape[2:])) + + for blk in self.blocks[:-self.num_main_blocks]: + x = blk(x) + + x = x[..., 0, 0, :] + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + pos_embed = torch.gather( + pos_embed.expand(B, -1, -1), + dim=1, + index=ids_keep[:, :, None].expand(-1, -1, + pos_embed.shape[2]), + ) + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x) + + return (x, mask, ids_restore) diff --git a/mmpretrain/models/selfsup/maskfeat.py b/mmpretrain/models/selfsup/maskfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9f0b296c44cdffe7f2a40caae04de0104abd60 --- /dev/null +++ b/mmpretrain/models/selfsup/maskfeat.py @@ -0,0 +1,336 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Sequence, Union + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.models import VisionTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class HOGGenerator(BaseModule): + """Generate HOG feature for images. + + This module is used in MaskFeat to generate HOG feature. The code is + modified from file `slowfast/models/operators.py + `_. + Here is the link of `HOG wikipedia + `_. + + Args: + nbins (int): Number of bin. Defaults to 9. + pool (float): Number of cell. Defaults to 8. + gaussian_window (int): Size of gaussian kernel. Defaults to 16. + """ + + def __init__(self, + nbins: int = 9, + pool: int = 8, + gaussian_window: int = 16) -> None: + super().__init__() + self.nbins = nbins + self.pool = pool + self.pi = math.pi + weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) + weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous() + weight_y = weight_x.transpose(2, 3).contiguous() + self.register_buffer('weight_x', weight_x) + self.register_buffer('weight_y', weight_y) + + self.gaussian_window = gaussian_window + if gaussian_window: + gaussian_kernel = self.get_gaussian_kernel(gaussian_window, + gaussian_window // 2) + self.register_buffer('gaussian_kernel', gaussian_kernel) + + def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor: + """Returns a 2D Gaussian kernel array.""" + + def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor: + n = torch.arange(0, kernlen).float() + n -= n.mean() + n /= std + w = torch.exp(-0.5 * n**2) + return w + + kernel_1d = _gaussian_fn(kernlen, std) + kernel_2d = kernel_1d[:, None] * kernel_1d[None, :] + return kernel_2d / kernel_2d.sum() + + def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor: + """Reshape HOG Features for output.""" + hog_feat = hog_feat.flatten(1, 2) + self.unfold_size = hog_feat.shape[-1] // 14 + hog_feat = hog_feat.permute(0, 2, 3, 1) + hog_feat = hog_feat.unfold(1, self.unfold_size, + self.unfold_size).unfold( + 2, self.unfold_size, self.unfold_size) + hog_feat = hog_feat.flatten(1, 2).flatten(2) + return hog_feat + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Generate hog feature for each batch images. + + Args: + x (torch.Tensor): Input images of shape (N, 3, H, W). + + Returns: + torch.Tensor: Hog features. + """ + # input is RGB image with shape [B 3 H W] + self.h, self.w = x.size(-2), x.size(-1) + x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect') + gx_rgb = F.conv2d( + x, self.weight_x, bias=None, stride=1, padding=0, groups=3) + gy_rgb = F.conv2d( + x, self.weight_y, bias=None, stride=1, padding=0, groups=3) + norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1) + phase = torch.atan2(gx_rgb, gy_rgb) + phase = phase / self.pi * self.nbins # [-9, 9] + + b, c, h, w = norm_rgb.shape + out = torch.zeros((b, c, self.nbins, h, w), + dtype=torch.float, + device=x.device) + phase = phase.view(b, c, 1, h, w) + norm_rgb = norm_rgb.view(b, c, 1, h, w) + if self.gaussian_window: + if h != self.gaussian_window: + assert h % self.gaussian_window == 0, 'h {} gw {}'.format( + h, self.gaussian_window) + repeat_rate = h // self.gaussian_window + temp_gaussian_kernel = self.gaussian_kernel.repeat( + [repeat_rate, repeat_rate]) + else: + temp_gaussian_kernel = self.gaussian_kernel + norm_rgb *= temp_gaussian_kernel + + out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb) + + out = out.unfold(3, self.pool, self.pool) + out = out.unfold(4, self.pool, self.pool) + out = out.sum(dim=[-1, -2]) + + self.out = F.normalize(out, p=2, dim=2) + + return self._reshape(self.out) + + def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray: + """Generate HOG image according to HOG features.""" + assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \ + 'Check the input batch size and the channcel number, only support'\ + '"batch_size = 1".' + hog_image = np.zeros([self.h, self.w]) + cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu()) + cell_width = self.pool / 2 + max_mag = np.array(cell_gradient).max() + angle_gap = 360 / self.nbins + + for x in range(cell_gradient.shape[1]): + for y in range(cell_gradient.shape[2]): + cell_grad = cell_gradient[:, x, y] + cell_grad /= max_mag + angle = 0 + for magnitude in cell_grad: + angle_radian = math.radians(angle) + x1 = int(x * self.pool + + magnitude * cell_width * math.cos(angle_radian)) + y1 = int(y * self.pool + + magnitude * cell_width * math.sin(angle_radian)) + x2 = int(x * self.pool - + magnitude * cell_width * math.cos(angle_radian)) + y2 = int(y * self.pool - + magnitude * cell_width * math.sin(angle_radian)) + magnitude = 0 if magnitude < 0 else magnitude + cv2.line(hog_image, (y1, x1), (y2, x2), + int(255 * math.sqrt(magnitude))) + angle += angle_gap + return hog_image + + +@MODELS.register_module() +class MaskFeatViT(VisionTransformer): + """Vision Transformer for MaskFeat pre-training. + + A PyTorch implement of: `Masked Feature Prediction for Self-Supervised + Visual Pre-Training `_. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.parameter.Parameter( + torch.zeros(1, 1, self.embed_dims), requires_grad=True) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, mask token and cls token.""" + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + nn.init.trunc_normal_(self.cls_token, std=.02) + nn.init.trunc_normal_(self.mask_token, std=.02) + nn.init.trunc_normal_(self.pos_embed, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m: torch.nn.Module) -> None: + if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> torch.Tensor: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images. + mask (torch.Tensor, optional): Input masks. + + Returns: + torch.Tensor: Features with cls_tokens. + """ + if mask is None: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + + # masking: length -> length * mask_ratio + B, L, _ = x.shape + mask_tokens = self.mask_token.expand(B, L, -1) + mask = mask.unsqueeze(-1) + x = x * (1 - mask.int()) + mask_tokens * mask + + # append cls token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.drop_after_pos(x) + + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + return x + + +@MODELS.register_module() +class MaskFeat(BaseSelfSupervisor): + """MaskFeat. + + Implementation of `Masked Feature Prediction for Self-Supervised Visual + Pre-Training `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + mask = mask.flatten(1).bool() + + latent = self.backbone(inputs, mask) + B, L, C = latent.shape + pred = self.neck((latent.view(B * L, C), )) + pred = pred[0].view(B, L, -1) + hog = self.target_generator(inputs) + + # remove cls_token before compute loss + loss = self.head.loss(pred[:, 1:], hog, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/milan.py b/mmpretrain/models/selfsup/milan.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf86737af3499e6f6309aa5c5ddadef00f63740 --- /dev/null +++ b/mmpretrain/models/selfsup/milan.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.runner.checkpoint import _load_checkpoint + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_clip_model +from .base import BaseSelfSupervisor +from .mae import MAEViT + + +@MODELS.register_module() +class CLIPGenerator(nn.Module): + """Get the features and attention from the last layer of CLIP. + + This module is used to generate target features in masked image modeling. + + Args: + tokenizer_path (str): The path of the checkpoint of CLIP. + """ + + def __init__(self, tokenizer_path: str) -> None: + super().__init__() + self.tokenizer_path = tokenizer_path + self.tokenizer = build_clip_model( + _load_checkpoint(self.tokenizer_path), False) + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the features and attention from the last layer of CLIP. + + Args: + x (torch.Tensor): The input image, which is of shape (N, 3, H, W). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The features and attention from + the last layer of CLIP, which are of shape (N, L, C) and (N, L, L), + respectively. + """ + # use the visual branch of CLIP to get the features + assert self.tokenizer is not None, 'Please check whether the ' \ + '`self.tokenizer` is initialized correctly.' + + clip_features = self.tokenizer.encode_image(x) + return clip_features + + +@MODELS.register_module() +class MILANViT(MAEViT): + """Vision Transformer for MILAN pre-training. + + Implementation of the encoder for `MILAN: Masked Image Pretraining on + Language Assisted Representation `_. + + This module inherits from MAEViT and only overrides the forward function + and replace random masking with attention masking. + """ + + def attention_masking( + self, x: torch.Tensor, mask_ratio: float, importance: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate attention mask for MILAN. + + This is what is different from MAEViT, which uses random masking. + Attention masking generates attention mask for MILAN, according to + importance. The higher the importance, the more likely the patch is + kept. + + Args: + x (torch.Tensor): Input images, which is of shape B x L x C. + mask_ratio (float): The ratio of patches to be masked. + importance (torch.Tensor): Importance of each patch, which is of + shape B x L. + + Returns: + Tuple[torch.Tensor, ...]: + + - ``x_masked``: masked image + - ``ids_restore``: the ids to restore original image + - ``ids_keep``: ids of the kept patches + - ``ids_dump``: ids of the removed patches + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = importance.to(x.device) # large is keep, small is remove + + # sort noise for each sample + ids_shuffle = torch.multinomial(noise, L, replacement=False) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + ids_dump = ids_shuffle[:, len_keep:] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, ids_restore, ids_keep, ids_dump + + def forward( + self, + x: torch.Tensor, + importance: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the + ``importance`` is ``None``, the function generates mask and masks some + patches randomly and get the hidden features for visible patches. The + mask is generated by importance. The higher the importance, the more + likely the patch is kept. The importance is calculated by CLIP. + The higher the CLIP score, the more likely the patch is kept. The CLIP + score is calculated by cross attention between the class token and all + other tokens from the last layer. + If the ``importance`` is ``torch.Tensor``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + importance (torch.Tensor, optional): Importance of each patch, + which is of shape B x L. + + Returns: + Tuple[torch.Tensor, ...]: masked image, the ids to restore original + image, ids of the kept patches, ids of the removed patches. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + - ``ids_keep`` (torch.Tensor): ids of the kept patches. + - ``ids_dump`` (torch.Tensor): ids of the removed patches. + """ + if importance is None: + return super(MAEViT, self).forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, ids_restore, ids_keep, ids_dump = self.attention_masking( + x, self.mask_ratio, importance) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for _, layer in enumerate(self.layers): + x = layer(x) + # Use final norm + x = self.norm1(x) + + return x, ids_restore, ids_keep, ids_dump + + +@MODELS.register_module() +class MILAN(BaseSelfSupervisor): + """MILAN. + + Implementation of `MILAN: Masked Image Pretraining on Language Assisted + Representation `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, importance=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + clip_feature, importance = self.target_generator(inputs) + importance = importance[:, 0, 1:] + latent, ids_restore, ids_keep, ids_dump = self.backbone( + inputs, importance) + pred = self.neck(latent, ids_restore, ids_keep, ids_dump) + + loss = self.head.loss(pred, clip_feature) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mixmim.py b/mmpretrain/models/selfsup/mixmim.py new file mode 100644 index 0000000000000000000000000000000000000000..b202f836f64358369276a9b85795fb6eec769fb7 --- /dev/null +++ b/mmpretrain/models/selfsup/mixmim.py @@ -0,0 +1,263 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from mmpretrain.models.backbones import MixMIMTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MixMIMPretrainTransformer(MixMIMTransformer): + """MixMIM backbone for MixMIM pre-training. + + A PyTorch implement of : ` MixMIM: Mixed and Masked Image + Modeling for Efficient Visual Representation Learning + `_ + + Args: + arch (str | dict): MixMIM architecture. If use string, + choose from 'base','large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + + Defaults to 'base'. + mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to mlp_ratio + the most common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + window_size (list): The height and width of the window. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + attn_drop_rate (float): Attention drop rate. Defaults to 0. + use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory + cost. Defaults to False. + mask_ratio (bool): The base ratio of total number of patches to be + masked. Defaults to 0.5. + range_mask_ratio (float): The range of mask ratio. + Defaults to 0. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'base', + mlp_ratio: float = 4, + img_size: int = 224, + patch_size: int = 4, + in_channels: int = 3, + window_size: List = [14, 14, 14, 7], + qkv_bias: bool = True, + patch_cfg: dict = dict(), + norm_cfg: dict = dict(type='LN'), + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + attn_drop_rate: float = 0.0, + use_checkpoint: bool = False, + mask_ratio: float = 0.5, + range_mask_ratio: float = 0.0, + init_cfg: Optional[dict] = None) -> None: + + super().__init__( + arch=arch, + mlp_ratio=mlp_ratio, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + window_size=window_size, + qkv_bias=qkv_bias, + patch_cfg=patch_cfg, + norm_cfg=norm_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + attn_drop_rate=attn_drop_rate, + use_checkpoint=use_checkpoint, + init_cfg=init_cfg) + + self.mask_ratio = mask_ratio + self.range_mask_ratio = range_mask_ratio + + def init_weights(self): + """Initialize position embedding, patch embedding.""" + super(MixMIMTransformer, self).init_weights() + + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.absolute_pos_embed.shape[-1], + cls_token=False) + self.absolute_pos_embed.data.copy_(pos_embed.float()) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def random_masking(self, + x: torch.Tensor, + mask_ratio: float = 0.5) -> Tuple[torch.Tensor]: + """Generate the mask for MixMIM Pretraining. + + Args: + x (torch.Tensor): Image with data augmentation applied, which is + of shape B x L x C. + mask_ratio (float): The mask ratio of total patches. + Defaults to 0.5. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + - mask_s1 (torch.Tensor): mask with stride of + self.encoder_stride // 8. + - mask_s2 (torch.Tensor): mask with stride of + self.encoder_stride // 4. + - mask_s3 (torch.Tensor): mask with stride of + self.encoder_stride // 2. + - mask (torch.Tensor): mask with stride of + self.encoder_stride. + """ + + B, C, H, W = x.shape + out_H = H // self.encoder_stride + out_W = W // self.encoder_stride + s3_H, s3_W = out_H * 2, out_W * 2 + s2_H, s2_W = out_H * 4, out_W * 4 + s1_H, s1_W = out_H * 8, out_W * 8 + + seq_l = out_H * out_W + # use a shared mask for a batch images + mask = torch.zeros([1, 1, seq_l], device=x.device) + + mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio) + noise = torch.rand(1, 1, seq_l, device=x.device) # noise in [0, 1] + # ascend: small is keep, large is removed + mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)] + mask.scatter_(2, mask_idx, 1) + mask = mask.reshape(1, 1, out_H, out_W) + mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest') + mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest') + mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest') + + mask = mask.reshape(1, out_H * out_W, 1).contiguous() + mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous() + mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous() + mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous() + + return mask_s1, mask_s2, mask_s3, mask + + def forward(self, + x: torch.Tensor, + mask: Optional[bool] = True) -> Tuple[torch.Tensor]: + """Generate features for masked images. + + This function generates mask and masks some patches randomly and get + the hidden features for visible patches. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward containing + ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - x (torch.Tensor): hidden features, which is of shape + B x L x C. + - mask_s4 (torch.Tensor): the mask tensor for the last layer. + """ + if mask is None or False: + return super().forward(x) + + else: + mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking( + x, self.mask_ratio) + + x, _ = self.patch_embed(x) + + x = x * (1. - mask_s1) + x.flip(0) * mask_s1 + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for idx, layer in enumerate(self.layers): + if idx == 0: + x = layer(x, attn_mask=mask_s1) + elif idx == 1: + x = layer(x, attn_mask=mask_s2) + elif idx == 2: + x = layer(x, attn_mask=mask_s3) + elif idx == 3: + x = layer(x, attn_mask=mask_s4) + + x = self.norm(x) + + return x, mask_s4 + + +@MODELS.register_module() +class MixMIM(BaseSelfSupervisor): + """MixMIM. + + Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient + Visual Representation Learning. `_. + """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + pretrained: Optional[str] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + + head.update(dict(patch_size=neck['encoder_stride'])) + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + latent, mask = self.backbone(inputs) + x_rec = self.neck(latent, mask) + loss = self.head.loss(x_rec, inputs, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/moco.py b/mmpretrain/models/selfsup/moco.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff4cf8fd6d0d6bca4724965d3b6d09543317748 --- /dev/null +++ b/mmpretrain/models/selfsup/moco.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_gather +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MoCo(BaseSelfSupervisor): + """MoCo. + + Implementation of `Momentum Contrast for Unsupervised Visual + Representation Learning `_. + Part of the code is borrowed from: + ``_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features to compact feature + vectors. + head (dict): Config dict for module of head functions. + queue_len (int): Number of negative keys maintained in the + queue. Defaults to 65536. + feat_dim (int): Dimension of compact feature vectors. + Defaults to 128. + momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.001. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + queue_len: int = 65536, + feat_dim: int = 128, + momentum: float = 0.001, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.encoder_k = ExponentialMovingAverage( + nn.Sequential(self.backbone, self.neck), momentum) + + # create the queue + self.queue_len = queue_len + self.register_buffer('queue', torch.randn(feat_dim, queue_len)) + self.queue = nn.functional.normalize(self.queue, dim=0) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None: + """Update queue.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue_ptr[0] = ptr + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + im_q = inputs[0] + im_k = inputs[1] + # compute query features from encoder_q + q = self.neck(self.backbone(im_q))[0] # queries: NxC + q = nn.functional.normalize(q, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + # update the key encoder + self.encoder_k.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + # shuffle for making use of BN + im_k, idx_unshuffle = batch_shuffle_ddp(im_k) + + k = self.encoder_k(im_k)[0] # keys: NxC + k = nn.functional.normalize(k, dim=1) + + # undo shuffle + k = batch_unshuffle_ddp(k, idx_unshuffle) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + loss = self.head.loss(l_pos, l_neg) + # update the queue + self._dequeue_and_enqueue(k) + + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mocov3.py b/mmpretrain/models/selfsup/mocov3.py new file mode 100644 index 0000000000000000000000000000000000000000..61b803387fdc129bc29056ee369fa3ad36c13e07 --- /dev/null +++ b/mmpretrain/models/selfsup/mocov3.py @@ -0,0 +1,215 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import reduce +from operator import mul +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.backbones import VisionTransformer +from mmpretrain.models.utils import (build_2d_sincos_position_embedding, + to_2tuple) +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import CosineEMA +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MoCoV3ViT(VisionTransformer): + """Vision Transformer for MoCoV3 pre-training. + + A pytorch implement of: `An Images is Worth 16x16 Words: Transformers for + Image Recognition at Scale `_. + + Part of the code is modified from: + ``_. + + Args: + stop_grad_conv1 (bool): whether to stop the gradient of + convolution layer in `PatchEmbed`. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + stop_grad_conv1: bool = False, + frozen_stages: int = -1, + norm_eval: bool = False, + init_cfg: Optional[Union[dict, List[dict]]] = None, + **kwargs) -> None: + + # add MoCoV3 ViT-small arch + self.arch_zoo.update( + dict.fromkeys( + ['mocov3-s', 'mocov3-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 1536, + })) + + super().__init__(init_cfg=init_cfg, **kwargs) + self.patch_size = kwargs['patch_size'] + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.init_cfg = init_cfg + + if stop_grad_conv1: + self.patch_embed.projection.weight.requires_grad = False + self.patch_embed.projection.bias.requires_grad = False + + self._freeze_stages() + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding, qkv layers and cls + token.""" + super().init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + # Use fixed 2D sin-cos position embedding + pos_emb = build_2d_sincos_position_embedding( + patches_resolution=self.patch_resolution, + embed_dims=self.embed_dims, + cls_token=True) + self.pos_embed.data.copy_(pos_emb) + self.pos_embed.requires_grad = False + + # xavier_uniform initialization for PatchEmbed + val = math.sqrt( + 6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) + + self.embed_dims)) + nn.init.uniform_(self.patch_embed.projection.weight, -val, val) + nn.init.zeros_(self.patch_embed.projection.bias) + + # initialization for linear layers + for name, m in self.named_modules(): + if isinstance(m, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt( + 6. / + float(m.weight.shape[0] // 3 + m.weight.shape[1])) + nn.init.uniform_(m.weight, -val, val) + else: + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + nn.init.normal_(self.cls_token, std=1e-6) + + def _freeze_stages(self) -> None: + """Freeze patch_embed layer, some parameters and stages.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + self.cls_token.requires_grad = False + self.pos_embed.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if i == (self.num_layers) and self.final_norm: + for param in getattr(self, 'norm1').parameters(): + param.requires_grad = False + + def train(self, mode: bool = True) -> None: + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() +class MoCoV3(BaseSelfSupervisor): + """MoCo v3. + + Implementation of `An Empirical Study of Training Self-Supervised Vision + Transformers `_. + + Args: + backbone (dict): Config dict for module of backbone + neck (dict): Config dict for module of deep features to compact feature + vectors. + head (dict): Config dict for module of head functions. + base_momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.01. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + base_momentum: float = 0.01, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.momentum_encoder = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + view_1 = inputs[0] + view_2 = inputs[1] + + # compute query features, [N, C] each + q1 = self.neck(self.backbone(view_1))[0] + q2 = self.neck(self.backbone(view_2))[0] + + # compute key features, [N, C] each, no gradient + with torch.no_grad(): + # update momentum encoder + self.momentum_encoder.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + k1 = self.momentum_encoder(view_1)[0] + k2 = self.momentum_encoder(view_2)[0] + + loss = self.head.loss(q1, k2) + self.head.loss(q2, k1) + + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/simclr.py b/mmpretrain/models/selfsup/simclr.py new file mode 100644 index 0000000000000000000000000000000000000000..4b19ab4053de21a865fbaf864f654ff3ad8840f1 --- /dev/null +++ b/mmpretrain/models/selfsup/simclr.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List, Tuple + +import torch +from mmengine.dist import all_gather, get_rank + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +class GatherLayer(torch.autograd.Function): + """Gather tensors from all process, supporting backward propagation.""" + + @staticmethod + def forward(ctx: Any, input: torch.Tensor) -> Tuple[List]: + ctx.save_for_backward(input) + output = all_gather(input) + return tuple(output) + + @staticmethod + def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor: + input, = ctx.saved_tensors + grad_out = torch.zeros_like(input) + grad_out[:] = grads[get_rank()] + return grad_out + + +@MODELS.register_module() +class SimCLR(BaseSelfSupervisor): + """SimCLR. + + Implementation of `A Simple Framework for Contrastive Learning of Visual + Representations `_. + """ + + @staticmethod + def _create_buffer( + batch_size: int, device: torch.device + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the mask and the index of positive samples. + + Args: + batch_size (int): The batch size. + device (torch.device): The device of backend. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - The mask for feature selection. + - The index of positive samples. + - The mask of negative samples. + """ + mask = 1 - torch.eye(batch_size * 2, dtype=torch.uint8).to(device) + pos_idx = ( + torch.arange(batch_size * 2).to(device), + 2 * torch.arange(batch_size, dtype=torch.long).unsqueeze(1).repeat( + 1, 2).view(-1, 1).squeeze().to(device)) + neg_mask = torch.ones((batch_size * 2, batch_size * 2 - 1), + dtype=torch.uint8).to(device) + neg_mask[pos_idx] = 0 + return mask, pos_idx, neg_mask + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + inputs = torch.stack(inputs, 1) + inputs = inputs.reshape((inputs.size(0) * 2, inputs.size(2), + inputs.size(3), inputs.size(4))) + x = self.backbone(inputs) + z = self.neck(x)[0] # (2n)xd + + z = z / (torch.norm(z, p=2, dim=1, keepdim=True) + 1e-10) + z = torch.cat(GatherLayer.apply(z), dim=0) # (2N)xd + assert z.size(0) % 2 == 0 + N = z.size(0) // 2 + s = torch.matmul(z, z.permute(1, 0)) # (2N)x(2N) + mask, pos_idx, neg_mask = self._create_buffer(N, s.device) + + # remove diagonal, (2N)x(2N-1) + s = torch.masked_select(s, mask == 1).reshape(s.size(0), -1) + positive = s[pos_idx].unsqueeze(1) # (2N)x1 + + # select negative, (2N)x(2N-2) + negative = torch.masked_select(s, neg_mask == 1).reshape(s.size(0), -1) + + loss = self.head.loss(positive, negative) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/simmim.py b/mmpretrain/models/selfsup/simmim.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf9494210c7a9d22853c4138542ba5c77d779f6 --- /dev/null +++ b/mmpretrain/models/selfsup/simmim.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models import SwinTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SimMIMSwinTransformer(SwinTransformer): + """Swin Transformer for SimMIM pre-training. + + Args: + Args: + arch (str | dict): Swin Transformer architecture + Defaults to 'T'. + img_size (int | tuple): The size of input image. + Defaults to 224. + in_channels (int): The num of input channels. + Defaults to 3. + drop_rate (float): Dropout rate after embedding. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. + Defaults to 0.1. + out_indices (tuple): Layers to be outputted. Defaults to (3, ). + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + norm_cfg (dict): Config dict for normalization layer at end + of backone. Defaults to dict(type='LN') + stage_cfgs (Sequence | dict): Extra config dict for each + stage. Defaults to empty dict. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to empty dict. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'T', + img_size: Union[Tuple[int, int], int] = 224, + in_channels: int = 3, + drop_rate: float = 0., + drop_path_rate: float = 0.1, + out_indices: tuple = (3, ), + use_abs_pos_embed: bool = False, + with_cp: bool = False, + frozen_stages: bool = -1, + norm_eval: bool = False, + norm_cfg: dict = dict(type='LN'), + stage_cfgs: Union[Sequence, dict] = dict(), + patch_cfg: dict = dict(), + pad_small_map: bool = False, + init_cfg: Optional[dict] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + in_channels=in_channels, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + out_indices=out_indices, + use_abs_pos_embed=use_abs_pos_embed, + with_cp=with_cp, + frozen_stages=frozen_stages, + norm_eval=norm_eval, + norm_cfg=norm_cfg, + stage_cfgs=stage_cfgs, + patch_cfg=patch_cfg, + pad_small_map=pad_small_map, + init_cfg=init_cfg) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize weights.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + trunc_normal_(self.mask_token, mean=0, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + """Initialize weights.""" + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> Sequence[torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images. + mask (torch.Tensor, optional): Masks for images. + + Returns: + tuple: A tuple containing features from multi-stages. + """ + if mask is None: + return super().forward(x) + + else: + x, hw_shape = self.patch_embed(x) + B, L, _ = x.shape + + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + stage.out_channels).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) + + +@MODELS.register_module() +class SimMIM(BaseSelfSupervisor): + """SimMIM. + + Implementation of `SimMIM: A Simple Framework for Masked Image Modeling + `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + + img_latent = self.backbone(inputs, mask) + img_rec = self.neck(img_latent[0]) + loss = self.head.loss(img_rec, inputs, mask) + losses = dict(loss=loss) + + return losses diff --git a/mmpretrain/models/selfsup/simsiam.py b/mmpretrain/models/selfsup/simsiam.py new file mode 100644 index 0000000000000000000000000000000000000000..a502cd770d0b497368dc7fc1d93caac01ec65db1 --- /dev/null +++ b/mmpretrain/models/selfsup/simsiam.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SimSiam(BaseSelfSupervisor): + """SimSiam. + + Implementation of `Exploring Simple Siamese Representation Learning + `_. The operation of fixing learning rate + of predictor is in `engine/hooks/simsiam_hook.py`. + """ + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + img_v1 = inputs[0] + img_v2 = inputs[1] + + z1 = self.neck(self.backbone(img_v1))[0] # NxC + z2 = self.neck(self.backbone(img_v2))[0] # NxC + + loss_1 = self.head.loss(z1, z2) + loss_2 = self.head.loss(z2, z1) + + losses = dict(loss=0.5 * (loss_1 + loss_2)) + return losses diff --git a/mmpretrain/models/selfsup/spark.py b/mmpretrain/models/selfsup/spark.py new file mode 100644 index 0000000000000000000000000000000000000000..d5570a5a9b17212aa400c3c6518a8e75a5c8c6c2 --- /dev/null +++ b/mmpretrain/models/selfsup/spark.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils.norm import build_norm_layer +from ..utils.sparse_modules import SparseHelper +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SparK(BaseSelfSupervisor): + """Implementation of SparK. + + Implementation of `Designing BERT for Convolutional Networks: Sparse and + Hierarchical Masked Modeling `_. + + Modified from + https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py + """ + + def __init__( + self, + backbone: dict, + neck: dict, + head: dict, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + input_size: int = 224, + downsample_raito: int = 32, + mask_ratio: float = 0.6, + enc_dec_norm_cfg=dict(type='SparseSyncBatchNorm2d'), + enc_dec_norm_dim: int = 2048, + init_cfg: Optional[dict] = None, + ) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + self.input_size = input_size + self.downsample_raito = downsample_raito + feature_map_size = input_size // downsample_raito + self.feature_map_size = feature_map_size + + self.mask_ratio = mask_ratio + self.len_keep = round(feature_map_size * feature_map_size * + (1 - mask_ratio)) + + self.enc_dec_norm_cfg = enc_dec_norm_cfg + self.enc_dec_norms = nn.ModuleList() + self.enc_dec_projectors = nn.ModuleList() + self.mask_tokens = nn.ParameterList() + + proj_out_dim = self.neck.feature_dim + for i in range(len(self.backbone.out_indices)): + enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg, + enc_dec_norm_dim) + self.enc_dec_norms.append(enc_dec_norm) + + kernel_size = 1 if i <= 0 else 3 + proj_layer = nn.Conv2d( + enc_dec_norm_dim, + proj_out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=True) + if i == 0 and enc_dec_norm_dim == proj_out_dim: + proj_layer = nn.Identity() + self.enc_dec_projectors.append(proj_layer) + + mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1)) + trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02) + self.mask_tokens.append(mask_token) + + enc_dec_norm_dim //= 2 + proj_out_dim //= 2 + feature_map_size *= 2 + + def mask(self, + shape: torch.Size, + device: Union[torch.device, str], + generator: Optional[torch.Generator] = None): + """Mask generation. + + Args: + shape (torch.Size): The shape of the input images. + device (Union[torch.device, str]): The device of the tensor. + generator (torch.Generator, optional): Generator for random + functions. Defaults to None + Returns: + torch.Tensor: The generated mask. + """ + B, C, H, W = shape + f = self.feature_map_size + idx = torch.rand(B, f * f, generator=generator).argsort(dim=1) + idx = idx[:, :self.len_keep].to(device) # (B, len_keep) + return torch.zeros( + B, f * f, dtype=torch.bool, device=device).scatter_( + dim=1, index=idx, value=True).view(B, 1, f, f) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + # active mask of feature map, (B, 1, f, f) + active_mask_feature_map = self.mask(inputs.shape, inputs.device) + SparseHelper._cur_active = active_mask_feature_map + + # active mask of original input, (B, 1, H, W) + active_mask_origin = active_mask_feature_map.repeat_interleave( + self.downsample_raito, + 2).repeat_interleave(self.downsample_raito, 3) + masked_img = inputs * active_mask_origin + + # get hierarchical encoded sparse features in a list + # containing four feature maps + feature_maps = self.backbone(masked_img) + + # from the smallest feature map to the largest + feature_maps = list(feature_maps) + feature_maps.reverse() + + cur_active = active_mask_feature_map + feature_maps_to_dec = [] + for i, feature_map in enumerate(feature_maps): + if feature_map is not None: + # fill in empty positions with [mask] embeddings + feature_map = self.enc_dec_norms[i](feature_map) + mask_token = self.mask_tokens[i].expand_as(feature_map) + feature_map = torch.where( + cur_active.expand_as(feature_map), feature_map, + mask_token.to(feature_map.dtype)) + feature_map = self.enc_dec_projectors[i](feature_map) + feature_maps_to_dec.append(feature_map) + + # dilate the mask map + cur_active = cur_active.repeat_interleave( + 2, dim=2).repeat_interleave( + 2, dim=3) + + # decode and reconstruct + rec_img = self.neck(feature_maps_to_dec) + + # compute loss + loss = self.head(rec_img, inputs, active_mask_feature_map) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/swav.py b/mmpretrain/models/selfsup/swav.py new file mode 100644 index 0000000000000000000000000000000000000000..efe0eab483319bd2dfde8929a2285e684cd3fc38 --- /dev/null +++ b/mmpretrain/models/selfsup/swav.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SwAV(BaseSelfSupervisor): + """SwAV. + + Implementation of `Unsupervised Learning of Visual Features by Contrasting + Cluster Assignments `_. + + The queue is built in ``mmpretrain/engine/hooks/swav_hook.py``. + """ + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """Forward computation during training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + # multi-res forward passes + idx_crops = torch.cumsum( + torch.unique_consecutive( + torch.tensor([input.shape[-1] for input in inputs]), + return_counts=True)[1], 0) + start_idx = 0 + output = [] + for end_idx in idx_crops: + _out = self.backbone(torch.cat(inputs[start_idx:end_idx])) + output.append(_out) + start_idx = end_idx + output = self.neck(output) + + loss = self.head.loss(output) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/tta/__init__.py b/mmpretrain/models/tta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..568e64ffdc743b4694045f39a46deb5083b2688a --- /dev/null +++ b/mmpretrain/models/tta/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .score_tta import AverageClsScoreTTA + +__all__ = ['AverageClsScoreTTA'] diff --git a/mmpretrain/models/tta/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/tta/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7683e786756e018b658cbe4127b441e53aaee55 Binary files /dev/null and b/mmpretrain/models/tta/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/tta/__pycache__/score_tta.cpython-38.pyc b/mmpretrain/models/tta/__pycache__/score_tta.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a39227ae40e3c503da4a7ae2bc9f63e171cc433f Binary files /dev/null and b/mmpretrain/models/tta/__pycache__/score_tta.cpython-38.pyc differ diff --git a/mmpretrain/models/tta/score_tta.py b/mmpretrain/models/tta/score_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8a0786577c6cdb5076957df0ed60aac9d307cb --- /dev/null +++ b/mmpretrain/models/tta/score_tta.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine.model import BaseTTAModel + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class AverageClsScoreTTA(BaseTTAModel): + + def merge_preds( + self, + data_samples_list: List[List[DataSample]], + ) -> List[DataSample]: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[List[DataSample]]): List of predictions + of all enhanced data. + + Returns: + List[DataSample]: Merged prediction. + """ + merged_data_samples = [] + for data_samples in data_samples_list: + merged_data_samples.append(self._merge_single_sample(data_samples)) + return merged_data_samples + + def _merge_single_sample(self, data_samples): + merged_data_sample: DataSample = data_samples[0].new() + merged_score = sum(data_sample.pred_score + for data_sample in data_samples) / len(data_samples) + merged_data_sample.set_pred_score(merged_score) + return merged_data_sample diff --git a/mmpretrain/models/utils/__init__.py b/mmpretrain/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e59d71d524308cbda3f4f693d1fb066b4a5981fa --- /dev/null +++ b/mmpretrain/models/utils/__init__.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpretrain.utils.dependency import WITH_MULTIMODAL +from .attention import (BEiTAttention, ChannelMultiheadAttention, + CrossMultiheadAttention, LeAttention, + MultiheadAttention, PromptMultiheadAttention, + ShiftWindowMSA, WindowMSA, WindowMSAV2) +from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix +from .batch_shuffle import batch_shuffle_ddp, batch_unshuffle_ddp +from .channel_shuffle import channel_shuffle +from .clip_generator_helper import QuickGELU, build_clip_model +from .data_preprocessor import (ClsDataPreprocessor, + MultiModalDataPreprocessor, + SelfSupDataPreprocessor, + TwoNormDataPreprocessor, VideoDataPreprocessor) +from .ema import CosineEMA +from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed, + resize_relative_position_bias_table) +from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple +from .inverted_residual import InvertedResidual +from .layer_scale import LayerScale +from .make_divisible import make_divisible +from .norm import GRN, LayerNorm2d, build_norm_layer +from .position_encoding import (ConditionalPositionEncoding, + PositionEncodingFourier, RotaryEmbeddingFast, + build_2d_sincos_position_embedding) +from .res_layer_extra_norm import ResLayerExtraNorm +from .se_layer import SELayer +from .sparse_modules import (SparseAvgPooling, SparseBatchNorm2d, SparseConv2d, + SparseHelper, SparseLayerNorm2D, SparseMaxPooling, + SparseSyncBatchNorm2d) +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .vector_quantizer import NormEMAVectorQuantizer + +__all__ = [ + 'channel_shuffle', + 'make_divisible', + 'InvertedResidual', + 'SELayer', + 'to_ntuple', + 'to_2tuple', + 'to_3tuple', + 'to_4tuple', + 'PatchEmbed', + 'PatchMerging', + 'HybridEmbed', + 'RandomBatchAugment', + 'ShiftWindowMSA', + 'is_tracing', + 'MultiheadAttention', + 'ConditionalPositionEncoding', + 'resize_pos_embed', + 'resize_relative_position_bias_table', + 'ClsDataPreprocessor', + 'Mixup', + 'CutMix', + 'ResizeMix', + 'BEiTAttention', + 'LayerScale', + 'WindowMSA', + 'WindowMSAV2', + 'ChannelMultiheadAttention', + 'PositionEncodingFourier', + 'LeAttention', + 'GRN', + 'LayerNorm2d', + 'build_norm_layer', + 'CrossMultiheadAttention', + 'build_2d_sincos_position_embedding', + 'PromptMultiheadAttention', + 'NormEMAVectorQuantizer', + 'build_clip_model', + 'batch_shuffle_ddp', + 'batch_unshuffle_ddp', + 'SelfSupDataPreprocessor', + 'TwoNormDataPreprocessor', + 'VideoDataPreprocessor', + 'CosineEMA', + 'ResLayerExtraNorm', + 'MultiModalDataPreprocessor', + 'QuickGELU', + 'SwiGLUFFN', + 'SwiGLUFFNFused', + 'RotaryEmbeddingFast', + 'SparseAvgPooling', + 'SparseConv2d', + 'SparseHelper', + 'SparseMaxPooling', + 'SparseBatchNorm2d', + 'SparseLayerNorm2D', + 'SparseSyncBatchNorm2d', +] + +if WITH_MULTIMODAL: + from .huggingface import (no_load_hf_pretrained_model, register_hf_model, + register_hf_tokenizer) + from .tokenizer import (Blip2Tokenizer, BlipTokenizer, FullTokenizer, + OFATokenizer) + + __all__.extend([ + 'BlipTokenizer', 'OFATokenizer', 'Blip2Tokenizer', 'register_hf_model', + 'register_hf_tokenizer', 'no_load_hf_pretrained_model', 'FullTokenizer' + ]) diff --git a/mmpretrain/models/utils/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca6731eed1b89b1a081ea6199e2ed428819606b Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/attention.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6c720dcedb022fa1dbc33f64df5c7416cc200f7 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/attention.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/batch_shuffle.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/batch_shuffle.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b2f524afa380fde3353c35d07805046f3eed2b8 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/batch_shuffle.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/box_utils.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/box_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15ef6337e144ca001c5d89a5a73e7174405b5d99 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/box_utils.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/channel_shuffle.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/channel_shuffle.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d5791788646572d239d84b601579d696e4bfbdc Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/channel_shuffle.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/clip_generator_helper.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/clip_generator_helper.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2cb40c052f54fdb878cfdc5e368a676ad96ad64 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/clip_generator_helper.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/data_preprocessor.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/data_preprocessor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12481d9a5904340c476af95f47700df965c4815a Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/data_preprocessor.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/ema.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/ema.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7609b2dae1892f2981eceed78e4199584ec787c7 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/ema.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/embed.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/embed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6638b3ee9e4262cca553e085d6d455d015cfc56 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/embed.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/helpers.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/helpers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..929ac393fb1fd80fb301fc883e66052a39b75d45 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/helpers.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/huggingface.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/huggingface.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a9da577ebb2805b23bd8c1e03f33c2c6176d093 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/huggingface.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/inverted_residual.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/inverted_residual.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a789cd8ccc7f919a5595cf58b67dc7b10f1c533b Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/inverted_residual.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/layer_scale.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/layer_scale.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..988b76a35986ed4d7a853214658190f22444576b Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/layer_scale.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/make_divisible.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/make_divisible.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad0136bbc3089133c565d40fb93126938889617f Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/make_divisible.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/norm.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/norm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..561a131d83bce5cf43e25376f860b3213c5086bf Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/norm.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/position_encoding.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/position_encoding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80ea041081fa57cf44ff476738d13dc91919effe Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/position_encoding.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/res_layer_extra_norm.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/res_layer_extra_norm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e18759e5beb29eebda2de42d564b2225ef366e8 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/res_layer_extra_norm.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/se_layer.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/se_layer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7bedf6b949ee3e36235e39e16c51ca0346a66f6 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/se_layer.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/sparse_modules.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/sparse_modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84529b3fe2ff28e8fd6976ee75337fc6682f1245 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/sparse_modules.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/swiglu_ffn.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/swiglu_ffn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa0f0dae49507ce5ee9252f13ef3fdcbc59beec6 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/swiglu_ffn.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/tokenizer.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..447e6acee0cfe87c77b5d855b0eb31642d40c8dc Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/tokenizer.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/vector_quantizer.cpython-38.pyc b/mmpretrain/models/utils/__pycache__/vector_quantizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4b3b04981f8967dc370a7bfde0d8b9b84a320c6 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/vector_quantizer.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/attention.py b/mmpretrain/models/utils/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e92f6054dd83881b508ac5e87d9034cd86b3a36c --- /dev/null +++ b/mmpretrain/models/utils/attention.py @@ -0,0 +1,1129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +import warnings +from functools import partial +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import digit_version + +from mmpretrain.registry import MODELS +from .helpers import to_2tuple +from .layer_scale import LayerScale + +# After pytorch v1.10.0, use torch.meshgrid without indexing +# will raise extra warning. For more details, +# refers to https://github.com/pytorch/pytorch/issues/50276 +if digit_version(torch.__version__) >= digit_version('1.10.0'): + torch_meshgrid = partial(torch.meshgrid, indexing='ij') +else: + torch_meshgrid = torch.meshgrid + + +def scaled_dot_product_attention_pyimpl(query, + key, + value, + attn_mask=None, + dropout_p=0., + scale=None, + is_causal=False): + scale = scale or query.size(-1)**0.5 + if is_causal and attn_mask is not None: + attn_mask = torch.ones( + query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0) + if attn_mask is not None and attn_mask.dtype == torch.bool: + attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) + + attn_weight = query @ key.transpose(-2, -1) / scale + if attn_mask is not None: + attn_weight += attn_mask + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + return attn_weight @ value + + +if digit_version(torch.__version__) >= digit_version('2.0.0'): + scaled_dot_product_attention = F.scaled_dot_product_attention +else: + scaled_dot_product_attention = scaled_dot_product_attention_pyimpl + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop (float, optional): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float, optional): Dropout ratio of output. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + super(WindowMSA, self).init_weights() + + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class WindowMSAV2(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Based on implementation on Swin Transformer V2 original repo. Refers to + https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer_v2.py + for more details. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + attn_drop (float): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float): Dropout ratio of output. Defaults to 0. + cpb_mlp_hidden_dims (int): The hidden dimensions of the continuous + relative position bias network. Defaults to 512. + pretrained_window_size (tuple(int)): The height and width of the window + in pre-training. Defaults to (0, 0), which means not load + pretrained model. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + cpb_mlp_hidden_dims=512, + pretrained_window_size=(0, 0), + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + + # Use small network for continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear( + in_features=2, out_features=cpb_mlp_hidden_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear( + in_features=cpb_mlp_hidden_dims, + out_features=num_heads, + bias=False)) + + # Add learnable scalar for cosine attention + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # get relative_coords_table + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), + self.window_size[0], + dtype=torch.float32) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), + self.window_size[1], + dtype=torch.float32) + relative_coords_table = torch.stack( + torch_meshgrid([relative_coords_h, relative_coords_w])).permute( + 1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= ( + pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= ( + pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + self.register_buffer('relative_coords_table', relative_coords_table) + + # get pair-wise relative position index + # for each token inside the window + indexes_h = torch.arange(self.window_size[0]) + indexes_w = torch.arange(self.window_size[1]) + coordinates = torch.stack( + torch_meshgrid([indexes_h, indexes_w]), dim=0) # 2, Wh, Ww + coordinates = torch.flatten(coordinates, start_dim=1) # 2, Wh*Ww + # 2, Wh*Ww, Wh*Ww + relative_coordinates = coordinates[:, :, None] - coordinates[:, + None, :] + relative_coordinates = relative_coordinates.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + + relative_coordinates[:, :, 0] += self.window_size[ + 0] - 1 # shift to start from 0 + relative_coordinates[:, :, 1] += self.window_size[1] - 1 + relative_coordinates[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coordinates.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(embed_dims)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + (self.q_bias, + torch.zeros_like(self.v_bias, + requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = ( + F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp( + self.logit_scale, max=np.log(1. / 0.01)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp( + self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +@MODELS.register_module() +class ShiftWindowMSA(BaseModule): + """Shift Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults to dict(type='DropPath', drop_prob=0.). + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + window_msa (Callable): To build a window multi-head attention module. + Defaults to :class:`WindowMSA`. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + **kwargs: Other keyword arguments to build the window multi-head + attention module. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + pad_small_map=False, + window_msa=WindowMSA, + init_cfg=None, + **kwargs): + super().__init__(init_cfg) + + self.shift_size = shift_size + self.window_size = window_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = window_msa( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(self.window_size), + **kwargs, + ) + + self.drop = build_dropout(dropout_layer) + self.pad_small_map = pad_small_map + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, f"The query length {L} doesn't match the input "\ + f'shape ({H}, {W}).' + query = query.view(B, H, W, C) + + window_size = self.window_size + shift_size = self.shift_size + + if min(H, W) == window_size: + # If not pad small feature map, avoid shifting when the window size + # is equal to the size of feature map. It's to align with the + # behavior of the original implementation. + shift_size = shift_size if self.pad_small_map else 0 + elif min(H, W) < window_size: + # In the original implementation, the window size will be shrunk + # to the size of feature map. The behavior is different with + # swin-transformer for downstream tasks. To support dynamic input + # shape, we don't allow this feature. + assert self.pad_small_map, \ + f'The input shape ({H}, {W}) is smaller than the window ' \ + f'size ({window_size}). Please set `pad_small_map=True`, or ' \ + 'decrease the `window_size`.' + + pad_r = (window_size - W % window_size) % window_size + pad_b = (window_size - H % window_size) % window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if shift_size > 0: + query = torch.roll( + query, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + attn_mask = self.get_attn_mask((H_pad, W_pad), + window_size=window_size, + shift_size=shift_size, + device=query.device) + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(query, window_size) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, window_size, window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, + window_size) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) + else: + x = shifted_x + + if H != H_pad or W != W_pad: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + + return x + + @staticmethod + def window_reverse(windows, H, W, window_size): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + @staticmethod + def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + @staticmethod + def get_attn_mask(hw_shape, window_size, shift_size, device=None): + if shift_size > 0: + img_mask = torch.zeros(1, *hw_shape, 1, device=device) + h_slices = (slice(0, -window_size), slice(-window_size, + -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -window_size), slice(-window_size, + -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = ShiftWindowMSA.window_partition( + img_mask, window_size) + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0) + attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0) + else: + attn_mask = None + return attn_mask + + +class MultiheadAttention(BaseModule): + """Multi-head Attention Module. + + This module implements multi-head attention that supports different input + dims and embed dims. And it also supports a shortcut from ``value``, which + is useful if input dims is not the same with embed dims. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + use_layer_scale (bool): Whether to use layer scale. Defaults to False. + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 0. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + input_dims=None, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + qkv_bias=True, + qk_scale=None, + proj_bias=True, + v_shortcut=False, + use_layer_scale=False, + layer_scale_init_value=0., + init_cfg=None): + super(MultiheadAttention, self).__init__(init_cfg=init_cfg) + + self.input_dims = input_dims or embed_dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.v_shortcut = v_shortcut + + self.head_dims = embed_dims // num_heads + if qk_scale is not None: + self.scaled_dot_product_attention = partial( + scaled_dot_product_attention_pyimpl, + scale=self.head_dims**-0.5) + else: + self.scaled_dot_product_attention = scaled_dot_product_attention + + self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.out_drop = build_dropout(dropout_layer) + + if use_layer_scale: + warnings.warn('The `use_layer_scale` in `MultiheadAttention` will ' + 'be deprecated. Please use `layer_scale_init_value` ' + 'to control whether using layer scale or not.') + + if use_layer_scale or (layer_scale_init_value > 0): + layer_scale_init_value = layer_scale_init_value or 1e-5 + self.gamma1 = LayerScale( + embed_dims, layer_scale_init_value=layer_scale_init_value) + else: + self.gamma1 = nn.Identity() + + def forward(self, x): + B, N, _ = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn_drop = self.attn_drop if self.training else 0. + x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + + x = self.proj(x) + x = self.out_drop(self.gamma1(self.proj_drop(x))) + + if self.v_shortcut: + x = v.squeeze(1) + x + return x + + +class BEiTAttention(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + The initial implementation is in MMSegmentation. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int, int]): The height and width of the window. + use_rel_pos_bias (bool): Whether to use unique relative position bias, + if False, use shared relative position bias defined in backbone. + bias (str): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + use_rel_pos_bias, + bias='qv_bias', + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.bias = bias + self.scale = qk_scale or head_embed_dims**-0.5 + + qkv_bias = bias + if bias == 'qv_bias': + self._init_qv_bias() + qkv_bias = False + + if window_size is None: + assert not use_rel_pos_bias + else: + assert isinstance(window_size, tuple) + self.window_size = window_size + self.use_rel_pos_bias = use_rel_pos_bias + self._init_rel_pos_embedding() + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + def _init_qv_bias(self): + self.q_bias = nn.Parameter(torch.zeros(self.embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(self.embed_dims)) + + def _init_rel_pos_embedding(self): + if self.use_rel_pos_bias: + Wh, Ww = self.window_size + # cls to token & token 2 cls & cls to cls + self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3 + # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH) + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, self.num_heads)) + + # get pair-wise relative position index for + # each token inside the window + coords_h = torch.arange(Wh) + coords_w = torch.arange(Ww) + # coords shape is (2, Wh, Ww) + coords = torch.stack(torch_meshgrid([coords_h, coords_w])) + # coords_flatten shape is (2, Wh*Ww) + coords_flatten = torch.flatten(coords, 1) + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :]) + # relative_coords shape is (Wh*Ww, Wh*Ww, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + # shift to start from 0 + relative_coords[:, :, 0] += Wh - 1 + relative_coords[:, :, 1] += Ww - 1 + relative_coords[:, :, 0] *= 2 * Ww - 1 + relative_position_index = torch.zeros( + size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype) + # relative_position_index shape is (Wh*Ww, Wh*Ww) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + def init_weights(self): + super().init_weights() + if self.use_rel_pos_bias: + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, rel_pos_bias=None): + """ + Args: + x (tensor): input features with shape of (num_windows*B, N, C). + rel_pos_bias (tensor): input relative position bias with shape of + (num_heads, N, N). + """ + B, N, C = x.shape + + if self.bias == 'qv_bias': + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + Wh = self.window_size[0] + Ww = self.window_size[1] + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + Wh * Ww + 1, Wh * Ww + 1, -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + # use shared relative position bias + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class ChannelMultiheadAttention(BaseModule): + """Channel Multihead Self-attention Module. + + This module implements channel multi-head attention that supports different + input dims and embed dims. + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shoutcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to False. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + qk_scale_type (str): The scale type of qk scale. + Defaults to 'learnable'. It can be 'learnable', 'fixed' or 'none'. + qk_scale (float, optional): If set qk_scale_type to 'none', this + should be specified with valid float number. Defaults to None. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads=8, + input_dims=None, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + qkv_bias=False, + proj_bias=True, + qk_scale_type='learnable', + qk_scale=None, + v_shortcut=False, + init_cfg=None): + super().__init__(init_cfg) + + self.input_dims = input_dims or embed_dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.v_shortcut = v_shortcut + + self.head_dims = embed_dims // num_heads + if qk_scale_type == 'learnable': + self.scale = nn.Parameter(torch.ones(num_heads, 1, 1)) + elif qk_scale_type == 'fixed': + self.scale = self.head_dims**-0.5 + elif qk_scale_type == 'none': + assert qk_scale is not None + self.scale = qk_scale + + self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.out_drop = build_dropout(dropout_layer) + + def forward(self, x): + B, N, _ = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + + q, k, v = [item.transpose(-2, -1) for item in [qkv[0], qkv[1], qkv[2]]] + + q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, self.embed_dims) + x = self.proj(x) + x = self.out_drop(self.proj_drop(x)) + + if self.v_shortcut: + x = qkv[2].squeeze(1) + x + return x + + +class LeAttention(BaseModule): + """LeViT Attention. Multi-head attention with attention bias, which is + proposed in `LeViT: a Vision Transformer in ConvNet’s Clothing for Faster + Inference`_ + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8. + key_dim (int): Dimension of key. Default: None. + attn_ratio (int): Ratio of attention heads. Default: 8. + resolution (tuple[int]): Input resolution. Default: (16, 16). + init_cfg (dict, optional): The Config for initialization. + """ + + def __init__(self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list( + itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer( + 'attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, + -1).split([self.key_dim, self.key_dim, self.d], + dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ((q @ k.transpose(-2, -1)) * self.scale + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class CrossMultiheadAttention(BaseModule): + """Cross attention between queries and the union of keys and values. + + This module is different from ``MultiheadAttention``, for the attention + is computed between queries and the union of keys and values. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + """ + + def __init__(self, + embed_dims: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0.) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(embed_dims, embed_dims, bias=False) + self.k = nn.Linear(embed_dims, embed_dims, bias=False) + self.v = nn.Linear(embed_dims, embed_dims, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(embed_dims)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, + x: torch.Tensor, + k: torch.Tensor = None, + v: torch.Tensor = None) -> None: + """Forward function.""" + B, N, _ = x.shape + + N_k = k.shape[1] + N_v = v.shape[1] + + q_bias, k_bias, v_bias = None, None, None + if self.q_bias is not None: + q_bias = self.q_bias + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + v_bias = self.v_bias + + q = F.linear( + input=x, weight=self.q.weight, bias=q_bias) # (B, N_q, dim) + k = F.linear( + input=k, weight=self.k.weight, bias=k_bias) # (B, N_k, dim) + v = F.linear(input=v, weight=self.v.weight, bias=v_bias) + + q = q.reshape(B, N, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_q, dim) + k = k.reshape(B, N_k, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_k, dim) + v = v.reshape(B, N_v, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_v, dim) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class PromptMultiheadAttention(MultiheadAttention): + """Prompt Multihead Attention for MILAN. + + This module is specific for the prompt encoder in MILAN. It will not update + the visible tokens from the encoder. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + return_attention (bool): If True, return the attention map, computed by + the cross attention between the class token and all other tokens. + Defaults to False. + init_cfg (Union[List[dict], dict], optional): The Config for + initialization. Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + input_dims: Optional[int] = None, + attn_drop: float = 0, + proj_drop: float = 0, + dropout_layer: dict = dict(type='Dropout', drop_prob=0.), + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + proj_bias: bool = True, + v_shortcut: bool = False, + use_layer_scale: bool = False, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + input_dims=input_dims, + attn_drop=attn_drop, + proj_drop=proj_drop, + dropout_layer=dropout_layer, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + proj_bias=proj_bias, + v_shortcut=v_shortcut, + use_layer_scale=use_layer_scale, + init_cfg=init_cfg) + # no longer need qkv + del self.qkv + + # to project the mask tokens + self.q = nn.Linear(embed_dims, embed_dims, bias=qkv_bias) + # to project al the tokens + self.kv = nn.Linear(embed_dims, embed_dims * 2, bias=qkv_bias) + + def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: + """Forward function for `PromptMultiheadAttention`. + + Args: + x (torch.Tensor): Mask token features with shape N x L_m x C. + visible_tokens (torch.Tensor): The visible tokens features from + encoder with shape N x L_v x C. + ids_restore (torch.Tensor): The ids of all tokens in the original + image with shape N x L. + + Returns: + torch Tensor: Output features with shape N x L x C. + """ + x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1) + assert x_.shape[1] == ids_restore.shape[1] + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + x_ = torch.cat([visible_tokens[:, :1, :], x_], dim=1) + + # full sequence shape + B, _, _ = x_.shape + q = self.q(x).reshape(B, x.shape[1], self.num_heads, + self.head_dims).permute(0, 2, 1, 3) + kv = self.kv(x_).reshape(B, x_.shape[1], 2, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn_drop = self.attn_drop if self.training else 0. + attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims) + + x = self.proj(x) + x = self.out_drop(self.gamma1(self.proj_drop(x))) + return x diff --git a/mmpretrain/models/utils/batch_augments/__init__.py b/mmpretrain/models/utils/batch_augments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbc4e179608767f667ca1075e5134dbecb8c38d --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .cutmix import CutMix +from .mixup import Mixup +from .resizemix import ResizeMix +from .wrapper import RandomBatchAugment + +__all__ = ('RandomBatchAugment', 'CutMix', 'Mixup', 'ResizeMix') diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b253b737ba70218eb280169f107414f98f76cb3 Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/cutmix.cpython-38.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/cutmix.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71f683d3e646cce010acecdd823f9a7b47e4b7fa Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/cutmix.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/mixup.cpython-38.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/mixup.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7531e7ffa0dbf4216c13893fde77178151a508c Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/mixup.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/resizemix.cpython-38.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/resizemix.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04adf00ebd5733292fa5c31b825f73f16b9c28fc Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/resizemix.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/wrapper.cpython-38.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/wrapper.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e411403e054bd08907ed82e7061b9aa77609dd6a Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/wrapper.cpython-38.pyc differ diff --git a/mmpretrain/models/utils/batch_augments/cutmix.py b/mmpretrain/models/utils/batch_augments/cutmix.py new file mode 100644 index 0000000000000000000000000000000000000000..665427bf5e2ff3a5ae9d656e7d642db8b72acabb --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/cutmix.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import numpy as np +import torch + +from mmpretrain.registry import BATCH_AUGMENTS +from .mixup import Mixup + + +@BATCH_AUGMENTS.register_module() +class CutMix(Mixup): + r"""CutMix batch agumentation. + + CutMix is a method to improve the network's generalization capability. It's + proposed in `CutMix: Regularization Strategy to Train Strong Classifiers + with Localizable Features ` + + With this method, patches are cut and pasted among training images where + the ground truth labels are also mixed proportionally to the area of the + patches. + + Args: + alpha (float): Parameters for Beta distribution to generate the + mixing ratio. It should be a positive number. More details + can be found in :class:`Mixup`. + cutmix_minmax (List[float], optional): The min/max area ratio of the + patches. If not None, the bounding-box of patches is uniform + sampled within this ratio range, and the ``alpha`` will be ignored. + Otherwise, the bounding-box is generated according to the + ``alpha``. Defaults to None. + correct_lam (bool): Whether to apply lambda correction when cutmix bbox + clipped by image borders. Defaults to True. + + .. note :: + If the ``cutmix_minmax`` is None, how to generate the bounding-box of + patches according to the ``alpha``? + + First, generate a :math:`\lambda`, details can be found in + :class:`Mixup`. And then, the area ratio of the bounding-box + is calculated by: + + .. math:: + \text{ratio} = \sqrt{1-\lambda} + """ + + def __init__(self, + alpha: float, + cutmix_minmax: Optional[List[float]] = None, + correct_lam: bool = True): + super().__init__(alpha=alpha) + + self.cutmix_minmax = cutmix_minmax + self.correct_lam = correct_lam + + def rand_bbox_minmax( + self, + img_shape: Tuple[int, int], + count: Optional[int] = None) -> Tuple[int, int, int, int]: + """Min-Max CutMix bounding-box Inspired by Darknet cutmix + implementation. It generates a random rectangular bbox based on min/max + percent values applied to each dimension of the input image. + + Typical defaults for minmax are usually in the .2-.3 for min and + .8-.9 range for max. + + Args: + img_shape (tuple): Image shape as tuple + count (int, optional): Number of bbox to generate. Defaults to None + """ + assert len(self.cutmix_minmax) == 2 + img_h, img_w = img_shape + cut_h = np.random.randint( + int(img_h * self.cutmix_minmax[0]), + int(img_h * self.cutmix_minmax[1]), + size=count) + cut_w = np.random.randint( + int(img_w * self.cutmix_minmax[0]), + int(img_w * self.cutmix_minmax[1]), + size=count) + yl = np.random.randint(0, img_h - cut_h, size=count) + xl = np.random.randint(0, img_w - cut_w, size=count) + yu = yl + cut_h + xu = xl + cut_w + return yl, yu, xl, xu + + def rand_bbox(self, + img_shape: Tuple[int, int], + lam: float, + margin: float = 0., + count: Optional[int] = None) -> Tuple[int, int, int, int]: + """Standard CutMix bounding-box that generates a random square bbox + based on lambda value. This implementation includes support for + enforcing a border margin as percent of bbox dimensions. + + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + margin (float): Percentage of bbox dimension to enforce as margin + (reduce amount of box outside image). Defaults to 0. + count (int, optional): Number of bbox to generate. Defaults to None + """ + ratio = np.sqrt(1 - lam) + img_h, img_w = img_shape + cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) + margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) + cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) + cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) + yl = np.clip(cy - cut_h // 2, 0, img_h) + yh = np.clip(cy + cut_h // 2, 0, img_h) + xl = np.clip(cx - cut_w // 2, 0, img_w) + xh = np.clip(cx + cut_w // 2, 0, img_w) + return yl, yh, xl, xh + + def cutmix_bbox_and_lam(self, + img_shape: Tuple[int, int], + lam: float, + count: Optional[int] = None) -> tuple: + """Generate bbox and apply lambda correction. + + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + count (int, optional): Number of bbox to generate. Defaults to None + """ + if self.cutmix_minmax is not None: + yl, yu, xl, xu = self.rand_bbox_minmax(img_shape, count=count) + else: + yl, yu, xl, xu = self.rand_bbox(img_shape, lam, count=count) + if self.correct_lam or self.cutmix_minmax is not None: + bbox_area = (yu - yl) * (xu - xl) + lam = 1. - bbox_area / float(img_shape[0] * img_shape[1]) + return (yl, yu, xl, xu), lam + + def mix(self, batch_inputs: torch.Tensor, + batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. + batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor): The mixed inputs and labels. + """ + lam = np.random.beta(self.alpha, self.alpha) + batch_size = batch_inputs.size(0) + img_shape = batch_inputs.shape[-2:] + index = torch.randperm(batch_size) + + (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) + batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2] + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] + + return batch_inputs, mixed_scores diff --git a/mmpretrain/models/utils/batch_augments/mixup.py b/mmpretrain/models/utils/batch_augments/mixup.py new file mode 100644 index 0000000000000000000000000000000000000000..bedb2c3e5b6e62595e50f7494eeda7c14827b391 --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/mixup.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import numpy as np +import torch + +from mmpretrain.registry import BATCH_AUGMENTS + + +@BATCH_AUGMENTS.register_module() +class Mixup: + r"""Mixup batch augmentation. + + Mixup is a method to reduces the memorization of corrupt labels and + increases the robustness to adversarial examples. It's proposed in + `mixup: Beyond Empirical Risk Minimization + `_ + + Args: + alpha (float): Parameters for Beta distribution to generate the + mixing ratio. It should be a positive number. More details + are in the note. + + Note: + The :math:`\alpha` (``alpha``) determines a random distribution + :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample + a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random + distribution. + """ + + def __init__(self, alpha: float): + assert isinstance(alpha, float) and alpha > 0 + + self.alpha = alpha + + def mix(self, batch_inputs: torch.Tensor, + batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. + batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor): The mixed inputs and labels. + """ + lam = np.random.beta(self.alpha, self.alpha) + batch_size = batch_inputs.size(0) + index = torch.randperm(batch_size) + + mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :] + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] + + return mixed_inputs, mixed_scores + + def __call__(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): + """Mix the batch inputs and batch data samples.""" + assert batch_score.ndim == 2, \ + 'The input `batch_score` should be a one-hot format tensor, '\ + 'which shape should be ``(N, num_classes)``.' + + mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score.float()) + return mixed_inputs, mixed_score diff --git a/mmpretrain/models/utils/batch_augments/resizemix.py b/mmpretrain/models/utils/batch_augments/resizemix.py new file mode 100644 index 0000000000000000000000000000000000000000..89cfb72033e75065502a594f17124eb1f471116f --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/resizemix.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from mmpretrain.registry import BATCH_AUGMENTS +from .cutmix import CutMix + + +@BATCH_AUGMENTS.register_module() +class ResizeMix(CutMix): + r"""ResizeMix Random Paste layer for a batch of data. + + The ResizeMix will resize an image to a small patch and paste it on another + image. It's proposed in `ResizeMix: Mixing Data with Preserved Object + Information and True Labels `_ + + Args: + alpha (float): Parameters for Beta distribution to generate the + mixing ratio. It should be a positive number. More details + can be found in :class:`Mixup`. + lam_min(float): The minimum value of lam. Defaults to 0.1. + lam_max(float): The maximum value of lam. Defaults to 0.8. + interpolation (str): algorithm used for upsampling: + 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | + 'area'. Defaults to 'bilinear'. + prob (float): The probability to execute resizemix. It should be in + range [0, 1]. Defaults to 1.0. + cutmix_minmax (List[float], optional): The min/max area ratio of the + patches. If not None, the bounding-box of patches is uniform + sampled within this ratio range, and the ``alpha`` will be ignored. + Otherwise, the bounding-box is generated according to the + ``alpha``. Defaults to None. + correct_lam (bool): Whether to apply lambda correction when cutmix bbox + clipped by image borders. Defaults to True + **kwargs: Any other parameters accpeted by :class:`CutMix`. + + Note: + The :math:`\lambda` (``lam``) is the mixing ratio. It's a random + variable which follows :math:`Beta(\alpha, \alpha)` and is mapped + to the range [``lam_min``, ``lam_max``]. + + .. math:: + \lambda = \frac{Beta(\alpha, \alpha)} + {\lambda_{max} - \lambda_{min}} + \lambda_{min} + + And the resize ratio of source images is calculated by :math:`\lambda`: + + .. math:: + \text{ratio} = \sqrt{1-\lambda} + """ + + def __init__(self, + alpha: float, + lam_min: float = 0.1, + lam_max: float = 0.8, + interpolation: str = 'bilinear', + cutmix_minmax: Optional[List[float]] = None, + correct_lam: bool = True): + super().__init__( + alpha=alpha, cutmix_minmax=cutmix_minmax, correct_lam=correct_lam) + self.lam_min = lam_min + self.lam_max = lam_max + self.interpolation = interpolation + + def mix(self, batch_inputs: torch.Tensor, + batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. + batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor): The mixed inputs and labels. + """ + lam = np.random.beta(self.alpha, self.alpha) + lam = lam * (self.lam_max - self.lam_min) + self.lam_min + img_shape = batch_inputs.shape[-2:] + batch_size = batch_inputs.size(0) + index = torch.randperm(batch_size) + + (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) + batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate( + batch_inputs[index], + size=(y2 - y1, x2 - x1), + mode=self.interpolation, + align_corners=False) + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] + + return batch_inputs, mixed_scores diff --git a/mmpretrain/models/utils/batch_augments/wrapper.py b/mmpretrain/models/utils/batch_augments/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..10e5304c3ca1a42428870ea5a00416007ca2e35c --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/wrapper.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Union + +import numpy as np +import torch + +from mmpretrain.registry import BATCH_AUGMENTS + + +class RandomBatchAugment: + """Randomly choose one batch augmentation to apply. + + Args: + augments (Callable | dict | list): configs of batch + augmentations. + probs (float | List[float] | None): The probabilities of each batch + augmentations. If None, choose evenly. Defaults to None. + + Example: + >>> import torch + >>> import torch.nn.functional as F + >>> from mmpretrain.models import RandomBatchAugment + >>> augments_cfg = [ + ... dict(type='CutMix', alpha=1.), + ... dict(type='Mixup', alpha=1.) + ... ] + >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3]) + >>> imgs = torch.rand(16, 3, 32, 32) + >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10) + >>> imgs, label = batch_augment(imgs, label) + + .. note :: + + To decide which batch augmentation will be used, it picks one of + ``augments`` based on the probabilities. In the example above, the + probability to use CutMix is 0.5, to use Mixup is 0.3, and to do + nothing is 0.2. + """ + + def __init__(self, augments: Union[Callable, dict, list], probs=None): + if not isinstance(augments, (tuple, list)): + augments = [augments] + + self.augments = [] + for aug in augments: + if isinstance(aug, dict): + self.augments.append(BATCH_AUGMENTS.build(aug)) + else: + self.augments.append(aug) + + if isinstance(probs, float): + probs = [probs] + + if probs is not None: + assert len(augments) == len(probs), \ + '``augments`` and ``probs`` must have same lengths. ' \ + f'Got {len(augments)} vs {len(probs)}.' + assert sum(probs) <= 1, \ + 'The total probability of batch augments exceeds 1.' + self.augments.append(None) + probs.append(1 - sum(probs)) + + self.probs = probs + + def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor): + """Randomly apply batch augmentations to the batch inputs and batch + data samples.""" + aug_index = np.random.choice(len(self.augments), p=self.probs) + aug = self.augments[aug_index] + + if aug is not None: + return aug(batch_input, batch_score) + else: + return batch_input, batch_score.float() diff --git a/mmpretrain/models/utils/batch_shuffle.py b/mmpretrain/models/utils/batch_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b03c5fec5f99295daed2872feff73dfc238140 --- /dev/null +++ b/mmpretrain/models/utils/batch_shuffle.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from mmengine.dist import all_gather, broadcast, get_rank + + +@torch.no_grad() +def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Batch shuffle, for making use of BatchNorm. + + Args: + x (torch.Tensor): Data in each GPU. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Output of shuffle operation. + - x_gather[idx_this]: Shuffled data. + - idx_unshuffle: Index for restoring. + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = torch.cat(all_gather(x), dim=0) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all) + + # broadcast to all gpus + broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + +@torch.no_grad() +def batch_unshuffle_ddp(x: torch.Tensor, + idx_unshuffle: torch.Tensor) -> torch.Tensor: + """Undo batch shuffle. + + Args: + x (torch.Tensor): Data in each GPU. + idx_unshuffle (torch.Tensor): Index for restoring. + + Returns: + torch.Tensor: Output of unshuffle operation. + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = torch.cat(all_gather(x), dim=0) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] diff --git a/mmpretrain/models/utils/box_utils.py b/mmpretrain/models/utils/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79db516c990f51a7c952404d932b6de022684fb4 --- /dev/null +++ b/mmpretrain/models/utils/box_utils.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torchvision.ops.boxes as boxes + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2.0, (y0 + y1) / 2.0, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +def box_iou(boxes1, boxes2): + """Return intersection-over-union (Jaccard index) between two sets of + boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise IoU values for + every element in boxes1 and boxes2 + """ + return boxes.box_iou(boxes1, boxes2) + + +def generalized_box_iou(boxes1, boxes2): + """Return generalized intersection-over-union (Jaccard index) between two + sets of boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU + values for every element in boxes1 and boxes2 + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + + return boxes.generalized_box_iou(boxes1, boxes2) diff --git a/mmpretrain/models/utils/channel_shuffle.py b/mmpretrain/models/utils/channel_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..27006a8065db35a14c4207ce6613104374b064ad --- /dev/null +++ b/mmpretrain/models/utils/channel_shuffle.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def channel_shuffle(x, groups): + """Channel Shuffle operation. + + This function enables cross-group information flow for multiple groups + convolution layers. + + Args: + x (Tensor): The input tensor. + groups (int): The number of groups to divide the input tensor + in the channel dimension. + + Returns: + Tensor: The output tensor after channel shuffle operation. + """ + + batch_size, num_channels, height, width = x.size() + assert (num_channels % groups == 0), ('num_channels should be ' + 'divisible by groups') + channels_per_group = num_channels // groups + + x = x.view(batch_size, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + x = x.view(batch_size, -1, height, width) + + return x diff --git a/mmpretrain/models/utils/clip_generator_helper.py b/mmpretrain/models/utils/clip_generator_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4f67f0ed6976585a20e15787fc6b94c41082d33d --- /dev/null +++ b/mmpretrain/models/utils/clip_generator_helper.py @@ -0,0 +1,394 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/zejiangh/MILAN +from collections import OrderedDict +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.logging import MMLogger +from torch import nn + +from mmpretrain.registry import MODELS + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +@MODELS.register_module() +class QuickGELU(nn.Module): + """A faster version of GELU.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + """Residual Attention Block (RAB). + + This module implements the same function as the MultiheadAttention, + but with a different interface, which is mainly used + in CLIP. + + Args: + d_model (int): The feature dimension. + n_head (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + Defaults to None. + """ + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: Optional[torch.Tensor] = None, + return_attention: bool = False) -> None: + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + self.return_attention = return_attention + + def attention(self, x: torch.Tensor) -> torch.Tensor: + """Attention function.""" + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + if self.return_attention: + return self.attn( + x, + x, + x, + need_weights=self.return_attention, + attn_mask=self.attn_mask) + else: + return self.attn( + x, + x, + x, + need_weights=self.return_attention, + attn_mask=self.attn_mask)[0] + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Forward function.""" + if self.return_attention: + x_, attention = self.attention(self.ln_1(x)) + x = x + x_ + x = x + self.mlp(self.ln_2(x)) + return x, attention + else: + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +class VisionTransformer(nn.Module): + """Vision Transformer for CLIP. + + Args: + input_resolution (int): The image size. + patch_size (int): The patch size. + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + out_dim (int): The output dimension. + fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. + """ + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + finetune=False, + average_targets: int = 1) -> None: + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.finetune = finetune + if finetune is False: + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.average_targets = average_targets + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function.""" + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([ + self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x + ], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, attention, z = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + + return x, attention + + +class CLIP(nn.Module): + """CLIP. + + Args: + embed_dim (int): The embedding dimension. + image_resolution (int): The image size. + vision_layers (int): The number of layers in the vision transformer. + vision_width (int): The feature dimension in the vision transformer. + vision_patch_size (int): The patch size in the vision transformer. + context_length (int): The context length. + vocab_size (int): The vocabulary size. + transformer_width (int): The feature dimension in the text transformer. + transformer_heads (int): The number of attention heads in the + text transformer. + transformer_layers (int): The number of layers in the text transformer. + fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. + """ + + def __init__( + self, + embed_dim: int, + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + finetune: bool = False, + average_targets: int = 1, + ) -> None: + super().__init__() + + self.context_length = context_length + + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + finetune=finetune, + average_targets=average_targets, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self) -> None: + """Initialize the parameters. + + The pretrained weight will override the initialized parameters by this + function. + """ + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self) -> torch.Tensor: + """Build the attention mask.""" + # lazily create causal attention mask, with full attention between the + # vision tokens pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self) -> torch.dtype: + """Get the dtype.""" + return self.visual.conv1.weight.dtype + + def encode_image(self, + image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode the image. + + Get the feature and attention mask from the last layer of the visual + branch of CLIP. + + Args: + image (torch.Tensor): The image tensor with shape NCHW. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The feature and attention mask. + """ + return self.visual(image.type(self.dtype)) + + +def build_clip_model(state_dict: dict, + finetune: bool = False, + average_targets: int = 1) -> nn.Module: + """Build the CLIP model. + + Args: + state_dict (dict): The pretrained state dict. + finetune (bool): Whether to fineturn the model. + average_targets (bool): Whether to average the target. + + Returns: + nn.Module: The CLIP model. + """ + vit = 'visual.proj' in state_dict + + if vit: + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) + image_resolution = vision_patch_size * grid_size + + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) + + model = CLIP( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + finetune, + average_targets, + ) + + for key in ['input_resolution', 'context_length', 'vocab_size']: + if key in state_dict: + del state_dict[key] + + msg = model.load_state_dict(state_dict, strict=False) + MMLogger.get_current_instance().info(f'Load CLIP model: {msg}') + return model.eval() diff --git a/mmpretrain/models/utils/data_preprocessor.py b/mmpretrain/models/utils/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..c407bd4c9361b9fae329854d4a36dab929fef143 --- /dev/null +++ b/mmpretrain/models/utils/data_preprocessor.py @@ -0,0 +1,620 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from numbers import Number +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from mmengine.model import (BaseDataPreprocessor, ImgDataPreprocessor, + stack_batch) + +from mmpretrain.registry import MODELS +from mmpretrain.structures import (DataSample, MultiTaskDataSample, + batch_label_to_onehot, cat_batch_labels, + tensor_split) +from .batch_augments import RandomBatchAugment + + +@MODELS.register_module() +class ClsDataPreprocessor(BaseDataPreprocessor): + """Image pre-processor for classification tasks. + + Comparing with the :class:`mmengine.model.ImgDataPreprocessor`, + + 1. It won't do normalization if ``mean`` is not specified. + 2. It does normalization and color space conversion after stacking batch. + 3. It supports batch augmentations like mixup and cutmix. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + - Do batch augmentations like Mixup and Cutmix during training. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (Number): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + to_onehot (bool): Whether to generate one-hot format gt-labels and set + to data samples. Defaults to False. + num_classes (int, optional): The number of classes. Defaults to None. + batch_augments (dict, optional): The batch augmentations settings, + including "augments" and "probs". For more details, see + :class:`mmpretrain.models.RandomBatchAugment`. + """ + + def __init__(self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Number = 0, + to_rgb: bool = False, + to_onehot: bool = False, + num_classes: Optional[int] = None, + batch_augments: Optional[dict] = None): + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.to_onehot = to_onehot + self.num_classes = num_classes + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both `mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + if batch_augments: + self.batch_augments = RandomBatchAugment(**batch_augments) + if not self.to_onehot: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().info( + 'Because batch augmentations are enabled, the data ' + 'preprocessor automatically enables the `to_onehot` ' + 'option to generate one-hot format labels.') + self.to_onehot = True + else: + self.batch_augments = None + + def forward(self, data: dict, training: bool = False) -> dict: + """Perform normalization, padding, bgr2rgb conversion and batch + augmentation based on ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + dict: Data in the same format as the model input. + """ + inputs = self.cast_data(data['inputs']) + + if isinstance(inputs, torch.Tensor): + # The branch if use `default_collate` as the collate_fn in the + # dataloader. + + # ------ To RGB ------ + if self.to_rgb and inputs.size(1) == 3: + inputs = inputs.flip(1) + + # -- Normalization --- + inputs = inputs.float() + if self._enable_normalize: + inputs = (inputs - self.mean) / self.std + + # ------ Padding ----- + if self.pad_size_divisor > 1: + h, w = inputs.shape[-2:] + + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant', + self.pad_value) + else: + # The branch if use `pseudo_collate` as the collate_fn in the + # dataloader. + + processed_inputs = [] + for input_ in inputs: + # ------ To RGB ------ + if self.to_rgb and input_.size(0) == 3: + input_ = input_.flip(0) + + # -- Normalization --- + input_ = input_.float() + if self._enable_normalize: + input_ = (input_ - self.mean) / self.std + + processed_inputs.append(input_) + # Combine padding and stack + inputs = stack_batch(processed_inputs, self.pad_size_divisor, + self.pad_value) + + data_samples = data.get('data_samples', None) + sample_item = data_samples[0] if data_samples is not None else None + + if isinstance(sample_item, DataSample): + batch_label = None + batch_score = None + + if 'gt_label' in sample_item: + gt_labels = [sample.gt_label for sample in data_samples] + batch_label, label_indices = cat_batch_labels(gt_labels) + batch_label = batch_label.to(self.device) + if 'gt_score' in sample_item: + gt_scores = [sample.gt_score for sample in data_samples] + batch_score = torch.stack(gt_scores).to(self.device) + elif self.to_onehot and 'gt_label' in sample_item: + assert batch_label is not None, \ + 'Cannot generate onehot format labels because no labels.' + num_classes = self.num_classes or sample_item.get( + 'num_classes') + assert num_classes is not None, \ + 'Cannot generate one-hot format labels because not set ' \ + '`num_classes` in `data_preprocessor`.' + batch_score = batch_label_to_onehot( + batch_label, label_indices, num_classes).to(self.device) + + # ----- Batch Augmentations ---- + if (training and self.batch_augments is not None + and batch_score is not None): + inputs, batch_score = self.batch_augments(inputs, batch_score) + + # ----- scatter labels and scores to data samples --- + if batch_label is not None: + for sample, label in zip( + data_samples, tensor_split(batch_label, + label_indices)): + sample.set_gt_label(label) + if batch_score is not None: + for sample, score in zip(data_samples, batch_score): + sample.set_gt_score(score) + elif isinstance(sample_item, MultiTaskDataSample): + data_samples = self.cast_data(data_samples) + + return {'inputs': inputs, 'data_samples': data_samples} + + +@MODELS.register_module() +class SelfSupDataPreprocessor(ImgDataPreprocessor): + """Image pre-processor for operations, like normalization and bgr to rgb. + + Compared with the :class:`mmengine.ImgDataPreprocessor`, this module + supports ``inputs`` as torch.Tensor or a list of torch.Tensor. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + bgr_to_rgb=bgr_to_rgb, + rgb_to_bgr=rgb_to_bgr, + non_blocking=non_blocking) + + self._channel_conversion = to_rgb or bgr_to_rgb or rgb_to_bgr + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. + """ + assert isinstance(data, + dict), 'Please use default_collate in dataloader, \ + instead of pseudo_collate.' + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + # Here is what is different from :class:`mmengine.ImgDataPreprocessor` + # Since there are multiple views for an image for some algorithms, + # e.g. SimCLR, each item in inputs is a list, containing multi-views + # for an image. + if isinstance(batch_inputs, list): + # channel transform + if self._channel_conversion: + batch_inputs = [ + _input[:, [2, 1, 0], ...] for _input in batch_inputs + ] + + # convert to float after channel conversion to ensure efficiency + batch_inputs = [_input.float() for _input in batch_inputs] + + # normalization. + if self._enable_normalize: + batch_inputs = [(_input - self.mean) / self.std + for _input in batch_inputs] + else: + # channel transform + if self._channel_conversion: + batch_inputs = batch_inputs[:, [2, 1, 0], ...] + + # convert to float after channel conversion to ensure efficiency + batch_inputs = batch_inputs.float() + + # normalization. + if self._enable_normalize: + batch_inputs = (batch_inputs - self.mean) / self.std + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class TwoNormDataPreprocessor(SelfSupDataPreprocessor): + """Image pre-processor for CAE, BEiT v1/v2, etc. + + Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module + will normalize the prediction image and target image with different + normalization parameters. + + Args: + mean (Sequence[float or int], optional): The pixel mean of image + channels. If ``to_rgb=True`` it means the mean value of R, G, B + channels. If the length of `mean` is 1, it means all channels have + the same mean value, or the input is a gray image. If it is not + specified, images will not be normalized. Defaults to None. + std (Sequence[float or int], optional): The pixel standard deviation of + image channels. If ``to_rgb=True`` it means the standard deviation + of R, G, B channels. If the length of `std` is 1, it means all + channels have the same standard deviation, or the input is a gray + image. If it is not specified, images will not be normalized. + Defaults to None. + second_mean (Sequence[float or int], optional): The description is + like ``mean``, it can be customized for targe image. Defaults to + None. + second_std (Sequence[float or int], optional): The description is + like ``std``, it can be customized for targe image. Defaults to + None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + non_blocking (bool): Whether block current process when transferring + data to device. Defaults to False. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + second_mean: Sequence[Union[float, int]] = None, + second_std: Sequence[Union[float, int]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + non_blocking: Optional[bool] = False): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + to_rgb=to_rgb, + non_blocking=non_blocking) + assert (second_mean is not None) and (second_std is not None), ( + 'mean and std should not be None while using ' + '`TwoNormDataPreprocessor`') + assert len(second_mean) == 3 or len(second_mean) == 1, ( + '`mean` should have 1 or 3 values, to be compatible with ' + f'RGB or gray image, but got {len(second_mean)} values') + assert len(second_std) == 3 or len(second_std) == 1, ( + '`std` should have 1 or 3 values, to be compatible with RGB ' + f'or gray image, but got {len(std)} values') + + self.register_buffer('second_mean', + torch.tensor(second_mean).view(-1, 1, 1), False) + self.register_buffer('second_std', + torch.tensor(second_std).view(-1, 1, 1), False) + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization and bgr2rgb conversion based on + ``BaseDataPreprocessor``. The ``batch_inputs`` in forward function is a + list. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. + """ + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + # channel transform + if self._channel_conversion: + batch_inputs = [ + _input[:, [2, 1, 0], ...] for _input in batch_inputs + ] + + # convert to float after channel conversion to ensure efficiency + batch_inputs = [_input.float() for _input in batch_inputs] + + # Normalization. Here is what is different from + # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target + # image and prediction image with different normalization params + if self._enable_normalize: + batch_inputs = [ + (batch_inputs[0] - self.mean) / self.std, + (batch_inputs[1] - self.second_mean) / self.second_std + ] + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class VideoDataPreprocessor(BaseDataPreprocessor): + """Video pre-processor for operations, like normalization and bgr to rgb + conversion . + + Compared with the :class:`mmaction.ActionDataPreprocessor`, this module + supports ``inputs`` as torch.Tensor or a list of torch.Tensor. + + Args: + mean (Sequence[float or int, optional): The pixel mean of channels + of images or stacked optical flow. Defaults to None. + std (Sequence[float or int], optional): The pixel standard deviation + of channels of images or stacked optical flow. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + to_rgb (bool): Whether to convert image from BGR to RGB. + Defaults to False. + format_shape (str): Format shape of input data. + Defaults to ``'NCHW'``. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + format_shape: str = 'NCHW') -> None: + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.format_shape = format_shape + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both ' \ + '`mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + if self.format_shape == 'NCHW': + normalizer_shape = (-1, 1, 1) + elif self.format_shape == 'NCTHW': + normalizer_shape = (-1, 1, 1, 1) + else: + raise ValueError(f'Invalid format shape: {format_shape}') + + self.register_buffer( + 'mean', + torch.tensor(mean, dtype=torch.float32).view(normalizer_shape), + False) + self.register_buffer( + 'std', + torch.tensor(std, dtype=torch.float32).view(normalizer_shape), + False) + else: + self._enable_normalize = False + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[List[torch.Tensor], Optional[list]]: Data in the same format + as the model input. + """ + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + if isinstance(batch_inputs, list): + # channel transform + if self.to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = [ + _input[..., [2, 1, 0], :, :] for _input in batch_inputs + ] + elif self.format_shape == 'NCTHW': + batch_inputs = [ + _input[..., [2, 1, 0], :, :, :] + for _input in batch_inputs + ] + else: + raise ValueError( + f'Invalid format shape: {self.format_shape}') + + # convert to float after channel conversion to ensure efficiency + batch_inputs = [_input.float() for _input in batch_inputs] + + # normalization + if self._enable_normalize: + batch_inputs = [(_input - self.mean) / self.std + for _input in batch_inputs] + + else: + # channel transform + if self.to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = batch_inputs[..., [2, 1, 0], :, :] + elif self.format_shape == 'NCTHW': + batch_inputs = batch_inputs[..., [2, 1, 0], :, :, :] + else: + raise ValueError( + f'Invalid format shape: {self.format_shape}') + + # convert to float after channel conversion to ensure efficiency + batch_inputs = batch_inputs.float() + + # normalization + if self._enable_normalize: + batch_inputs = (batch_inputs - self.mean) / self.std + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class MultiModalDataPreprocessor(BaseDataPreprocessor): + """Data pre-processor for image-text multimodality tasks. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (Number): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + """ + + def __init__( + self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Number = 0, + to_rgb: bool = False, + ): + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both `mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + def forward(self, data: dict, training: bool = False) -> dict: + """Perform normalization, padding, bgr2rgb conversion and batch + augmentation based on ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + dict: Data in the same format as the model input. + """ + data = self.cast_data(data) + + imgs = data.get('inputs', None) + + def _process_img(img): + # ------ To RGB ------ + if self.to_rgb and img.size(1) == 3: + img = img.flip(1) + + # -- Normalization --- + img = img.float() + if self._enable_normalize: + img = (img - self.mean) / self.std + + # ------ Padding ----- + if self.pad_size_divisor > 1: + h, w = img.shape[-2:] + + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + img = F.pad(img, (0, pad_w, 0, pad_h), 'constant', + self.pad_value) + return img + + if isinstance(imgs, torch.Tensor): + imgs = _process_img(imgs) + elif isinstance(imgs, Sequence): + # B, T, C, H, W + imgs = torch.stack([_process_img(img) for img in imgs], dim=1) + elif imgs is not None: + raise ValueError(f'{type(imgs)} is not supported for imgs inputs.') + + data_samples = data.get('data_samples', None) + + return {'images': imgs, 'data_samples': data_samples} diff --git a/mmpretrain/models/utils/ema.py b/mmpretrain/models/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..63c5006bbb0d9ff967b3cce7d3b5ada0cc683468 --- /dev/null +++ b/mmpretrain/models/utils/ema.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from math import cos, pi +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.logging import MessageHub +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CosineEMA(ExponentialMovingAverage): + r"""CosineEMA is implemented for updating momentum parameter, used in BYOL, + MoCoV3, etc. + + All parameters are updated by the formula as below: + + .. math:: + + X'_{t+1} = (1 - m) * X'_t + m * X_t + + Where :math:`m` the the momentum parameter. And it's updated with cosine + annealing, including momentum adjustment following: + + .. math:: + m = m_{end} + (m_{end} - m_{start}) * (\cos\frac{k\pi}{K} + 1) / 2 + + where :math:`k` is the current step, :math:`K` is the total steps. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, + :math:`X'_{t}` is the moving average and :math:`X_t` is the new + observed value. The value of momentum is usually a small number, + allowing observed values to slowly update the ema parameters. See also + :external:py:class:`torch.nn.BatchNorm2d`. + + Args: + model (nn.Module): The model to be averaged. + momentum (float): The start momentum value. Defaults to 0.004. + end_momentum (float): The end momentum value for cosine annealing. + Defaults to 0. + interval (int): Interval between two updates. Defaults to 1. + device (torch.device, optional): If provided, the averaged model will + be stored on the :attr:`device`. Defaults to None. + update_buffers (bool): if True, it will compute running averages for + both the parameters and the buffers of the model. Defaults to + False. + """ + + def __init__(self, + model: nn.Module, + momentum: float = 0.004, + end_momentum: float = 0., + interval: int = 1, + device: Optional[torch.device] = None, + update_buffers: bool = False) -> None: + super().__init__( + model=model, + momentum=momentum, + interval=interval, + device=device, + update_buffers=update_buffers) + self.end_momentum = end_momentum + + def avg_func(self, averaged_param: torch.Tensor, + source_param: torch.Tensor, steps: int) -> None: + """Compute the moving average of the parameters using the cosine + momentum strategy. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + + Returns: + Tensor: The averaged parameters. + """ + message_hub = MessageHub.get_current_instance() + max_iters = message_hub.get_info('max_iters') + cosine_annealing = (cos(pi * steps / float(max_iters)) + 1) / 2 + momentum = self.end_momentum - (self.end_momentum - + self.momentum) * cosine_annealing + averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum) diff --git a/mmpretrain/models/utils/embed.py b/mmpretrain/models/utils/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8299f9a06789768b26ea58260a2984024fbf801d --- /dev/null +++ b/mmpretrain/models/utils/embed.py @@ -0,0 +1,423 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import AdaptivePadding +from mmengine.model import BaseModule + +from .helpers import to_2tuple + + +def resize_pos_embed(pos_embed, + src_shape, + dst_shape, + mode='bicubic', + num_extra_tokens=1): + """Resize pos_embed weights. + + Args: + pos_embed (torch.Tensor): Position embedding weights with shape + [1, L, C]. + src_shape (tuple): The resolution of downsampled origin training + image, in format (H, W). + dst_shape (tuple): The resolution of downsampled new training + image, in format (H, W). + mode (str): Algorithm used for upsampling. Choose one from 'nearest', + 'linear', 'bilinear', 'bicubic' and 'trilinear'. + Defaults to 'bicubic'. + num_extra_tokens (int): The number of extra tokens, such as cls_token. + Defaults to 1. + + Returns: + torch.Tensor: The resized pos_embed of shape [1, L_new, C] + """ + if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: + return pos_embed + assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]' + _, L, C = pos_embed.shape + src_h, src_w = src_shape + assert L == src_h * src_w + num_extra_tokens, \ + f"The length of `pos_embed` ({L}) doesn't match the expected " \ + f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \ + '`img_size` argument.' + extra_tokens = pos_embed[:, :num_extra_tokens] + + src_weight = pos_embed[:, num_extra_tokens:] + src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) + + # The cubic interpolate algorithm only accepts float32 + dst_weight = F.interpolate( + src_weight.float(), size=dst_shape, align_corners=False, mode=mode) + dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) + dst_weight = dst_weight.to(src_weight.dtype) + + return torch.cat((extra_tokens, dst_weight), dim=1) + + +def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head): + """Resize relative position bias table. + + Args: + src_shape (int): The resolution of downsampled origin training + image, in format (H, W). + dst_shape (int): The resolution of downsampled new training + image, in format (H, W). + table (tensor): The relative position bias of the pretrained model. + num_head (int): Number of attention heads. + + Returns: + torch.Tensor: The resized relative position bias table. + """ + from scipy import interpolate + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_shape // 2) + if gp > dst_shape // 2: + right = q + else: + left = q + + dis = [] + cur = 1 + for i in range(src_shape // 2): + dis.append(cur) + cur += q**(i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_shape // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + all_rel_pos_bias = [] + + for i in range(num_head): + z = table[:, i].view(src_shape, src_shape).float().numpy() + f_cubic = interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f_cubic(dx, + dy)).contiguous().view(-1, + 1).to(table.device)) + new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + return new_rel_pos_bias + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + img_size (int | tuple): The size of input image. Default: 224 + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None + conv_cfg (dict, optional): The config dict for conv layers. + Default: None + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None + """ + + def __init__(self, + img_size=224, + in_channels=3, + embed_dims=768, + norm_cfg=None, + conv_cfg=None, + init_cfg=None): + super(PatchEmbed, self).__init__(init_cfg) + warnings.warn('The `PatchEmbed` in mmpretrain will be deprecated. ' + 'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. ' + "It's more general and supports dynamic input shape") + + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + self.img_size = img_size + self.embed_dims = embed_dims + + # Use conv layer to embed + conv_cfg = conv_cfg or dict() + _conv_cfg = dict( + type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1) + _conv_cfg.update(conv_cfg) + self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims) + + # Calculate how many patches a input image is splited to. + h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] - + self.projection.dilation[i] * + (self.projection.kernel_size[i] - 1) - 1) // + self.projection.stride[i] + 1 for i in range(2)] + + self.patches_resolution = (h_out, w_out) + self.num_patches = h_out * w_out + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't " \ + f'match model ({self.img_size[0]}*{self.img_size[1]}).' + # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim + x = self.projection(x).flatten(2).transpose(1, 2) + + if self.norm is not None: + x = self.norm(x) + + return x + + +# Modified from pytorch-image-models +class HybridEmbed(BaseModule): + """CNN Feature Map Embedding. + + Extract feature map from CNN, flatten, + project to embedding dim. + + Args: + backbone (nn.Module): CNN backbone + img_size (int | tuple): The size of input image. Default: 224 + feature_size (int | tuple, optional): Size of feature map extracted by + CNN backbone. Default: None + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_cfg (dict, optional): The config dict for conv layers. + Default: None. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + backbone, + img_size=224, + feature_size=None, + in_channels=3, + embed_dims=768, + conv_cfg=None, + init_cfg=None): + super(HybridEmbed, self).__init__(init_cfg) + assert isinstance(backbone, nn.Module) + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of + # determining the exact dim of the output feature + # map for all networks, the feature metadata has + # reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of + # each stage that isn't captured. + training = backbone.training + if training: + backbone.eval() + o = self.backbone( + torch.zeros(1, in_channels, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + # last feature if backbone outputs list/tuple of features + o = o[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + self.num_patches = feature_size[0] * feature_size[1] + + # Use conv layer to embed + conv_cfg = conv_cfg or dict() + _conv_cfg = dict( + type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1) + _conv_cfg.update(conv_cfg) + self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + # last feature if backbone outputs list/tuple of features + x = x[-1] + x = self.projection(x).flatten(2).transpose(1, 2) + return x + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + Modified from mmcv, and this module supports specifying whether to use + post-norm. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map ((used in Swin Transformer)). Our + implementation uses :class:`torch.nn.Unfold` to merge patches, which is + about 25% faster than the original implementation. However, we need to + modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. To gets fully covered + by filter and stride you specified. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Defaults to None, which means to be set as + ``kernel_size``. + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Defaults to "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Defaults to 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults to False. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + use_post_norm (bool): Whether to use post normalization here. + Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + use_post_norm=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.use_post_norm = use_post_norm + + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adaptive_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adaptive_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + if norm_cfg is not None: + # build pre or post norm layer based on different channels + if self.use_post_norm: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + else: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + + if self.adaptive_padding: + x = self.adaptive_padding(x) + H, W = x.shape[-2:] + + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + x = self.sampler(x) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + + if self.use_post_norm: + # use post-norm here + x = self.reduction(x) + x = self.norm(x) if self.norm else x + else: + x = self.norm(x) if self.norm else x + x = self.reduction(x) + + return x, output_size diff --git a/mmpretrain/models/utils/helpers.py b/mmpretrain/models/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..971f45054e5edac15c71aa64ddd26164bf404d22 --- /dev/null +++ b/mmpretrain/models/utils/helpers.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections.abc +import warnings +from itertools import repeat + +import torch +from mmengine.utils import digit_version + + +def is_tracing() -> bool: + """Determine whether the model is called during the tracing of code with + ``torch.jit.trace``.""" + if digit_version(torch.__version__) >= digit_version('1.6.0'): + on_trace = torch.jit.is_tracing() + # In PyTorch 1.6, torch.jit.is_tracing has a bug. + # Refers to https://github.com/pytorch/pytorch/issues/42448 + if isinstance(on_trace, bool): + return on_trace + else: + return torch._C._is_tracing() + else: + warnings.warn( + 'torch.jit.is_tracing is only supported after v1.6.0. ' + 'Therefore is_tracing returns False automatically. Please ' + 'set on_trace manually if you are using trace.', UserWarning) + return False + + +# From PyTorch internals +def _ntuple(n): + """A `to_tuple` function generator. + + It returns a function, this function will repeat the input to a tuple of + length ``n`` if the input is not an Iterable object, otherwise, return the + input directly. + + Args: + n (int): The number of the target length. + """ + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/mmpretrain/models/utils/huggingface.py b/mmpretrain/models/utils/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..e527315b26e5d3f34c10d22e75d47b4050de4748 --- /dev/null +++ b/mmpretrain/models/utils/huggingface.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import contextlib +from typing import Optional + +import transformers +from mmengine.registry import Registry +from transformers import AutoConfig, PreTrainedModel +from transformers.models.auto.auto_factory import _BaseAutoModelClass + +from mmpretrain.registry import MODELS, TOKENIZER + + +def register_hf_tokenizer( + cls: Optional[type] = None, + registry: Registry = TOKENIZER, +): + """Register HuggingFace-style PreTrainedTokenizerBase class.""" + if cls is None: + + # use it as a decorator: @register_hf_tokenizer() + def _register(cls): + register_hf_tokenizer(cls=cls) + return cls + + return _register + + def from_pretrained(**kwargs): + if ('pretrained_model_name_or_path' not in kwargs + and 'name_or_path' not in kwargs): + raise TypeError( + f'{cls.__name__}.from_pretrained() missing required ' + "argument 'pretrained_model_name_or_path' or 'name_or_path'.") + # `pretrained_model_name_or_path` is too long for config, + # add an alias name `name_or_path` here. + name_or_path = kwargs.pop('pretrained_model_name_or_path', + kwargs.pop('name_or_path')) + return cls.from_pretrained(name_or_path, **kwargs) + + registry._register_module(module=from_pretrained, module_name=cls.__name__) + return cls + + +_load_hf_pretrained_model = True + + +@contextlib.contextmanager +def no_load_hf_pretrained_model(): + global _load_hf_pretrained_model + _load_hf_pretrained_model = False + yield + _load_hf_pretrained_model = True + + +def register_hf_model( + cls: Optional[type] = None, + registry: Registry = MODELS, +): + """Register HuggingFace-style PreTrainedModel class.""" + if cls is None: + + # use it as a decorator: @register_hf_tokenizer() + def _register(cls): + register_hf_model(cls=cls) + return cls + + return _register + + if issubclass(cls, _BaseAutoModelClass): + get_config = AutoConfig.from_pretrained + from_config = cls.from_config + elif issubclass(cls, PreTrainedModel): + get_config = cls.config_class.from_pretrained + from_config = cls + else: + raise TypeError('Not auto model nor pretrained model of huggingface.') + + def build(**kwargs): + if ('pretrained_model_name_or_path' not in kwargs + and 'name_or_path' not in kwargs): + raise TypeError( + f'{cls.__name__} missing required argument ' + '`pretrained_model_name_or_path` or `name_or_path`.') + # `pretrained_model_name_or_path` is too long for config, + # add an alias name `name_or_path` here. + name_or_path = kwargs.pop('pretrained_model_name_or_path', + kwargs.pop('name_or_path')) + + if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model: + return cls.from_pretrained(name_or_path, **kwargs) + else: + cfg = get_config(name_or_path, **kwargs) + return from_config(cfg) + + registry._register_module(module=build, module_name=cls.__name__) + return cls + + +register_hf_model(transformers.AutoModelForCausalLM) diff --git a/mmpretrain/models/utils/inverted_residual.py b/mmpretrain/models/utils/inverted_residual.py new file mode 100644 index 0000000000000000000000000000000000000000..8387b21251aacff8efcb1b048e37ecdfa1299b2b --- /dev/null +++ b/mmpretrain/models/utils/inverted_residual.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule + +from .se_layer import SELayer + + +class InvertedResidual(BaseModule): + """Inverted Residual Block. + + Args: + in_channels (int): The input channels of this module. + out_channels (int): The output channels of this module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Defaults to 3. + stride (int): The stride of the depthwise convolution. Defaults to 1. + se_cfg (dict, optional): Config dict for se layer. Defaults to None, + which means no se layer. + conv_cfg (dict): Config dict for convolution layer. Defaults to None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict | list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_expand_conv = (mid_channels != in_channels) + + if self.with_se: + assert isinstance(se_cfg, dict) + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if self.with_se: + self.se = SELayer(**se_cfg) + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + """Forward function. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/mmpretrain/models/utils/layer_scale.py b/mmpretrain/models/utils/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..bb480a15ce35570a5fcfe060c25ef676730430a7 --- /dev/null +++ b/mmpretrain/models/utils/layer_scale.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +import torch.nn as nn + + +class LayerScale(nn.Module): + """LayerScale layer. + + Args: + dim (int): Dimension of input features. + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 1e-5. + inplace (bool): inplace: can optionally do the + operation in-place. Defaults to False. + data_format (str): The input data format, could be 'channels_last' + or 'channels_first', representing (B, C, H, W) and + (B, N, C) format data respectively. Defaults to 'channels_last'. + """ + + def __init__(self, + dim: int, + layer_scale_init_value: Union[float, torch.Tensor] = 1e-5, + inplace: bool = False, + data_format: str = 'channels_last'): + super().__init__() + assert data_format in ('channels_last', 'channels_first'), \ + "'data_format' could only be channels_last or channels_first." + self.inplace = inplace + self.data_format = data_format + self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value) + + def forward(self, x): + if self.data_format == 'channels_first': + if self.inplace: + return x.mul_(self.weight.view(-1, 1, 1)) + else: + return x * self.weight.view(-1, 1, 1) + return x.mul_(self.weight) if self.inplace else x * self.weight diff --git a/mmpretrain/models/utils/make_divisible.py b/mmpretrain/models/utils/make_divisible.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec74689e37d4a9d605a595adb0cca1da88aa19a --- /dev/null +++ b/mmpretrain/models/utils/make_divisible.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number down to the nearest value that can + be divisible by the divisor. + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int, optional): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel + number to the original channel number. Default: 0.9. + Returns: + int: The modified output channel number + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/mmpretrain/models/utils/norm.py b/mmpretrain/models/utils/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8b890a0c6ec654f00e4bb4cd148158eaeba7599d --- /dev/null +++ b/mmpretrain/models/utils/norm.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class GRN(nn.Module): + """Global Response Normalization Module. + + Come from `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked + Autoencoders `_ + + Args: + in_channels (int): The number of channels of the input tensor. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-6. + """ + + def __init__(self, in_channels, eps=1e-6): + super().__init__() + self.in_channels = in_channels + self.gamma = nn.Parameter(torch.zeros(in_channels)) + self.beta = nn.Parameter(torch.zeros(in_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor, data_format='channel_first'): + """Forward method. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". + """ + if data_format == 'channel_last': + gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps) + x = self.gamma * (x * nx) + self.beta + x + elif data_format == 'channel_first': + gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps) + x = self.gamma.view(1, -1, 1, 1) * (x * nx) + self.beta.view( + 1, -1, 1, 1) + x + return x + + +@MODELS.register_module('LN2d') +class LayerNorm2d(nn.LayerNorm): + """LayerNorm on channels for 2d images. + + Args: + num_channels (int): The number of channels of the input tensor. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-5. + elementwise_affine (bool): a boolean value that when set to ``True``, + this module has learnable per-element affine parameters initialized + to ones (for weights) and zeros (for biases). Defaults to True. + """ + + def __init__(self, num_channels: int, **kwargs) -> None: + super().__init__(num_channels, **kwargs) + self.num_channels = self.normalized_shape[0] + + def forward(self, x, data_format='channel_first'): + """Forward method. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". + """ + assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \ + f'(N, C, H, W), but got tensor with shape {x.shape}' + if data_format == 'channel_last': + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, + self.eps) + elif data_format == 'channel_first': + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, + self.eps) + # If the output is discontiguous, it may cause some unexpected + # problem in the downstream tasks + x = x.permute(0, 3, 1, 2).contiguous() + return x + + +def build_norm_layer(cfg: dict, num_features: int) -> nn.Module: + """Build normalization layer. + + Args: + cfg (dict): The norm layer config, which should contain: + + - type (str): Layer type. + - layer args: Args needed to instantiate a norm layer. + + num_features (int): Number of input channels. + + Returns: + nn.Module: The created norm layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + norm_layer = MODELS.get(layer_type) + if norm_layer is None: + raise KeyError(f'Cannot find {layer_type} in registry under scope ' + f'name {MODELS.scope}') + + requires_grad = cfg_.pop('requires_grad', True) + cfg_.setdefault('eps', 1e-5) + + if layer_type != 'GN': + layer = norm_layer(num_features, **cfg_) + else: + layer = norm_layer(num_channels=num_features, **cfg_) + + if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): + layer._specify_ddp_gpu_num(1) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return layer diff --git a/mmpretrain/models/utils/position_encoding.py b/mmpretrain/models/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..07a3c486a25a84633d7e50463dd8b09f1c222837 --- /dev/null +++ b/mmpretrain/models/utils/position_encoding.py @@ -0,0 +1,247 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import partial +from typing import Optional, Sequence, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from mmengine.utils import digit_version + +from ..utils import to_2tuple + +# After pytorch v1.10.0, use torch.meshgrid without indexing +# will raise extra warning. For more details, +# refers to https://github.com/pytorch/pytorch/issues/50276 +if digit_version(torch.__version__) >= digit_version('1.10.0'): + torch_meshgrid = partial(torch.meshgrid, indexing='ij') +else: + torch_meshgrid = torch.meshgrid + + +class ConditionalPositionEncoding(BaseModule): + """The Conditional Position Encoding (CPE) module. + + The CPE is the implementation of 'Conditional Positional Encodings + for Vision Transformers '_. + + Args: + in_channels (int): Number of input channels. + embed_dims (int): The feature dimension. Default: 768. + stride (int): Stride of conv layer. Default: 1. + """ + + def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): + super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + groups=embed_dims) + self.stride = stride + + def forward(self, x, hw_shape): + B, N, C = x.shape + H, W = hw_shape + feat_token = x + # convert (B, N, C) to (B, C, H, W) + cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W).contiguous() + if self.stride == 1: + x = self.proj(cnn_feat) + cnn_feat + else: + x = self.proj(cnn_feat) + x = x.flatten(2).transpose(1, 2) + return x + + +class PositionEncodingFourier(BaseModule): + """The Position Encoding Fourier (PEF) module. + + The PEF is adopted from EdgeNeXt '_. + Args: + in_channels (int): Number of input channels. + Default: 32 + embed_dims (int): The feature dimension. + Default: 768. + temperature (int): Temperature. + Default: 10000. + dtype (torch.dtype): The data type. + Default: torch.float32. + init_cfg (dict): The config dict for initializing the module. + Default: None. + """ + + def __init__(self, + in_channels=32, + embed_dims=768, + temperature=10000, + dtype=torch.float32, + init_cfg=None): + super(PositionEncodingFourier, self).__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d(in_channels * 2, embed_dims, kernel_size=1) + self.scale = 2 * math.pi + self.in_channels = in_channels + self.embed_dims = embed_dims + self.dtype = dtype + + if digit_version(torch.__version__) < digit_version('1.8.0'): + floor_div = torch.floor_divide + else: + floor_div = partial(torch.div, rounding_mode='floor') + dim_t = torch.arange(in_channels, dtype=self.dtype) + self.dim_t = temperature**(2 * floor_div(dim_t, 2) / in_channels) + + def forward(self, bhw_shape): + B, H, W = bhw_shape + mask = torch.zeros(B, H, W).bool().to(self.proj.weight.device) + not_mask = ~mask + eps = 1e-6 + y_embed = not_mask.cumsum(1, dtype=self.dtype) + x_embed = not_mask.cumsum(2, dtype=self.dtype) + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = self.dim_t.to(mask.device) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.proj(pos) + + return pos + + +def build_2d_sincos_position_embedding( + patches_resolution: Union[int, Sequence[int]], + embed_dims: int, + temperature: Optional[int] = 10000., + cls_token: Optional[bool] = False) -> torch.Tensor: + """The function is to build position embedding for model to obtain the + position information of the image patches. + + Args: + patches_resolution (Union[int, Sequence[int]]): The resolution of each + patch. + embed_dims (int): The dimension of the embedding vector. + temperature (int, optional): The temperature parameter. Defaults to + 10000. + cls_token (bool, optional): Whether to concatenate class token. + Defaults to False. + + Returns: + torch.Tensor: The position embedding vector. + """ + + if isinstance(patches_resolution, int): + patches_resolution = (patches_resolution, patches_resolution) + + h, w = patches_resolution + grid_w = torch.arange(w, dtype=torch.float32) + grid_h = torch.arange(h, dtype=torch.float32) + grid_w, grid_h = torch_meshgrid(grid_w, grid_h) + assert embed_dims % 4 == 0, \ + 'Embed dimension must be divisible by 4.' + pos_dim = embed_dims // 4 + + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1. / (temperature**omega) + out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) + out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) + + pos_emb = torch.cat( + [ + torch.sin(out_w), + torch.cos(out_w), + torch.sin(out_h), + torch.cos(out_h) + ], + dim=1, + )[None, :, :] + + if cls_token: + cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32) + pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1) + + return pos_emb + + +class RotaryEmbeddingFast(BaseModule): + """Implements 2D rotary embedding (RoPE) for image tokens. Position + encoding is implemented with sin and cos functions, + + .. math:: + Pos_{cos} = cos(\frac{t}{\theta^{\frac{2i}{d}}} \\ + Pos_{sin} = sin(\frac{t}{\theta^{\frac{2i}{d}}} + Args: + embed_dims (int): The feature dimension for each head. + patch_resolution (int | tuple): The resolution of the + image, in format (H, W). + theta (float): The hyperparameter for position coding. + Defaults to 10000. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + patch_resolution, + theta=10000., + init_cfg=None): + super(RotaryEmbeddingFast, self).__init__(init_cfg=init_cfg) + + self.half_dim = embed_dims // 2 + self.patch_resolution = to_2tuple(patch_resolution) + self.theta = theta + + freqs_cos, freqs_sin = self.compute_position_embedding() + self.register_buffer('freqs_cos', freqs_cos) + self.register_buffer('freqs_sin', freqs_sin) + + def compute_position_embedding(self): + frequency = self.theta**( + torch.arange(0, self.half_dim, 2).float() / self.half_dim) + frequency = 1. / frequency + + h, w = self.patch_resolution + th = torch.arange(h) / h * self.half_dim + tw = torch.arange(w) / w * self.half_dim + + position_h = (th[:, None] @ frequency[None, :]).repeat(1, 2) + position_w = (tw[:, None] @ frequency[None, :]).repeat(1, 2) + + height = position_h[:, None, :].expand(h, w, self.half_dim) + width = position_w[None, :, :].expand(h, w, self.half_dim) + position = torch.cat((height, width), dim=-1) + + freqs_cos = position.cos().view(-1, position.shape[-1]) + freqs_sin = position.sin().view(-1, position.shape[-1]) + + return freqs_cos, freqs_sin + + def forward(self, x, patch_resolution): + # Check whether the patch resolution is the predefined size + patch_resolution = to_2tuple(patch_resolution) + if patch_resolution != self.patch_resolution: + self.patch_resolution = patch_resolution + freqs_cos, freqs_sin = self.compute_position_embedding() + self.register_buffer('freqs_cos', freqs_cos.to(x.device)) + self.register_buffer('freqs_sin', freqs_sin.to(x.device)) + + batch, num_heads, num_patches, dim = x.shape + + inputs = x + x = x.reshape(batch, num_heads, num_patches, -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + x = x.reshape(batch, num_heads, num_patches, dim) + + return inputs * self.freqs_cos + x * self.freqs_sin diff --git a/mmpretrain/models/utils/res_layer_extra_norm.py b/mmpretrain/models/utils/res_layer_extra_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..37e387ba9795ec528bd210dab75bd05abdc0addf --- /dev/null +++ b/mmpretrain/models/utils/res_layer_extra_norm.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .norm import build_norm_layer + +try: + from mmdet.models.backbones import ResNet + from mmdet.models.roi_heads.shared_heads.res_layer import ResLayer + from mmdet.registry import MODELS + + @MODELS.register_module() + class ResLayerExtraNorm(ResLayer): + """Add extra norm to original ``ResLayer``.""" + + def __init__(self, *args, **kwargs): + super(ResLayerExtraNorm, self).__init__(*args, **kwargs) + + block = ResNet.arch_settings[kwargs['depth']][0] + self.add_module( + 'norm', + build_norm_layer(self.norm_cfg, + 64 * 2**self.stage * block.expansion)) + + def forward(self, x): + """Forward function.""" + res_layer = getattr(self, f'layer{self.stage + 1}') + norm = getattr(self, 'norm') + x = res_layer(x) + out = norm(x) + return out + +except ImportError: + ResLayerExtraNorm = None diff --git a/mmpretrain/models/utils/se_layer.py b/mmpretrain/models/utils/se_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..20290171008c2fd6f7a9e14e444f23b8375abe22 --- /dev/null +++ b/mmpretrain/models/utils/se_layer.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.utils import is_tuple_of + +from .make_divisible import make_divisible + + +class SELayer(BaseModule): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + squeeze_channels (None or int): The intermediate channel number of + SElayer. Default: None, means the value of ``squeeze_channels`` + is ``make_divisible(channels // ratio, divisor)``. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will + be ``make_divisible(channels // ratio, divisor)``. Only used when + ``squeeze_channels`` is None. Default: 16. + divisor(int): The divisor to true divide the channel number. Only + used when ``squeeze_channels`` is None. Default: 8. + conv_cfg (None or dict): Config dict for convolution layer. Default: + None, which means using conv2d. + return_weight(bool): Whether to return the weight. Default: False. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Default: (dict(type='ReLU'), dict(type='Sigmoid')) + """ + + def __init__(self, + channels, + squeeze_channels=None, + ratio=16, + divisor=8, + bias='auto', + conv_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + return_weight=False, + init_cfg=None): + super(SELayer, self).__init__(init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + if squeeze_channels is None: + squeeze_channels = make_divisible(channels // ratio, divisor) + assert isinstance(squeeze_channels, int) and squeeze_channels > 0, \ + '"squeeze_channels" should be a positive integer, but get ' + \ + f'{squeeze_channels} instead.' + self.return_weight = return_weight + self.conv1 = ConvModule( + in_channels=channels, + out_channels=squeeze_channels, + kernel_size=1, + stride=1, + bias=bias, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=squeeze_channels, + out_channels=channels, + kernel_size=1, + stride=1, + bias=bias, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + if self.return_weight: + return out + else: + return x * out diff --git a/mmpretrain/models/utils/sparse_modules.py b/mmpretrain/models/utils/sparse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..dd6bf345399bbb9c1c2ec4af6c19cfe7adf9beb6 --- /dev/null +++ b/mmpretrain/models/utils/sparse_modules.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved. +# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS + + +class SparseHelper: + """The helper to compute sparse operation with pytorch, such as sparse + convlolution, sparse batch norm, etc.""" + + _cur_active: torch.Tensor = None + + @staticmethod + def _get_active_map_or_index(H: int, + returning_active_map: bool = True + ) -> torch.Tensor: + """Get current active map with (B, 1, f, f) shape or index format.""" + # _cur_active with shape (B, 1, f, f) + downsample_raito = H // SparseHelper._cur_active.shape[-1] + active_ex = SparseHelper._cur_active.repeat_interleave( + downsample_raito, 2).repeat_interleave(downsample_raito, 3) + return active_ex if returning_active_map else active_ex.squeeze( + 1).nonzero(as_tuple=True) + + @staticmethod + def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor: + """Sparse convolution forward function.""" + x = super(type(self), self).forward(x) + + # (b, c, h, w) *= (b, 1, h, w), mask the output of conv + x *= SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=True) + return x + + @staticmethod + def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor: + """Sparse batch norm forward function.""" + active_index = SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=False) + + # (b, c, h, w) -> (b, h, w, c) + x_permuted = x.permute(0, 2, 3, 1) + + # select the features on non-masked positions to form flatten features + # with shape (n, c) + x_flattened = x_permuted[active_index] + + # use BN1d to normalize this flatten feature (n, c) + x_flattened = super(type(self), self).forward(x_flattened) + + # generate output + output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype) + output[active_index] = x_flattened + + # (b, h, w, c) -> (b, c, h, w) + output = output.permute(0, 3, 1, 2) + return output + + +class SparseConv2d(nn.Conv2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +class SparseMaxPooling(nn.MaxPool2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +class SparseAvgPooling(nn.AvgPool2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +@MODELS.register_module() +class SparseBatchNorm2d(nn.BatchNorm1d): + """hack: override the forward function. + See `sp_bn_forward` above for more details + """ + forward = SparseHelper.sp_bn_forward + + +@MODELS.register_module() +class SparseSyncBatchNorm2d(nn.SyncBatchNorm): + """hack: override the forward function. + See `sp_bn_forward` above for more details + """ + forward = SparseHelper.sp_bn_forward + + +@MODELS.register_module('SparseLN2d') +class SparseLayerNorm2D(nn.LayerNorm): + """Implementation of sparse LayerNorm on channels for 2d images.""" + + def forward(self, + x: torch.Tensor, + data_format='channel_first') -> torch.Tensor: + """Sparse layer norm forward function with 2D data. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". + """ + assert x.dim() == 4, ( + f'LayerNorm2d only supports inputs with shape ' + f'(N, C, H, W), but got tensor with shape {x.shape}') + if data_format == 'channel_last': + index = SparseHelper._get_active_map_or_index( + H=x.shape[1], returning_active_map=False) + + # select the features on non-masked positions to form flatten + # features with shape (n, c) + x_flattened = x[index] + # use LayerNorm to normalize this flatten feature (n, c) + x_flattened = super().forward(x_flattened) + + # generate output + x = torch.zeros_like(x, dtype=x_flattened.dtype) + x[index] = x_flattened + elif data_format == 'channel_first': + index = SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=False) + x_permuted = x.permute(0, 2, 3, 1) + + # select the features on non-masked positions to form flatten + # features with shape (n, c) + x_flattened = x_permuted[index] + # use LayerNorm to normalize this flatten feature (n, c) + x_flattened = super().forward(x_flattened) + + # generate output + x = torch.zeros_like(x_permuted, dtype=x_flattened.dtype) + x[index] = x_flattened + x = x.permute(0, 3, 1, 2).contiguous() + else: + raise NotImplementedError + return x diff --git a/mmpretrain/models/utils/swiglu_ffn.py b/mmpretrain/models/utils/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..20b4591f4f09ae185dd28e432dff7919d98d3a50 --- /dev/null +++ b/mmpretrain/models/utils/swiglu_ffn.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.drop import build_dropout + +from .layer_scale import LayerScale +from .norm import build_norm_layer + + +class SwiGLUFFN(nn.Module): + """SwiGLU FFN layer. + + Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0., + bias: bool = True, + dropout_layer: Optional[dict] = None, + norm_cfg: Optional[dict] = None, + add_identity: bool = True, + ) -> None: + super().__init__() + self.embed_dims = embed_dims + self.out_dims = out_dims or embed_dims + hidden_dims = feedforward_channels or embed_dims + + self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, hidden_dims) + else: + self.norm = nn.Identity() + + self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias) + + if layer_scale_init_value > 0: + self.gamma2 = LayerScale( + dim=embed_dims, layer_scale_init_value=layer_scale_init_value) + else: + self.gamma2 = nn.Identity() + + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + def forward(self, + x: torch.Tensor, + identity: Optional[torch.Tensor] = None) -> torch.Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + hidden = self.norm(hidden) + out = self.w3(hidden) + out = self.gamma2(out) + out = self.dropout_layer(out) + + if self.out_dims != self.embed_dims or not self.add_identity: + # due to the dimension inconsistence or user setting + # not to apply residual operation + return out + + if identity is None: + identity = x + return identity + out + + +class SwiGLUFFNFused(SwiGLUFFN): + """SwiGLU FFN layer with fusing. + + Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0., + bias: bool = True, + ) -> None: + out_dims = out_dims or embed_dims + feedforward_channels = feedforward_channels or embed_dims + feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8 + super().__init__( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + out_dims=out_dims, + layer_scale_init_value=layer_scale_init_value, + bias=bias, + ) diff --git a/mmpretrain/models/utils/tokenizer.py b/mmpretrain/models/utils/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8a324bad00ff03a9ce24dc4cff222e379f1520 --- /dev/null +++ b/mmpretrain/models/utils/tokenizer.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections +import os + +from mmengine.fileio import list_from_file +from transformers import (AutoTokenizer, BartTokenizer, BasicTokenizer, + BertTokenizer, BertTokenizerFast, LlamaTokenizer, + WordpieceTokenizer) + +from mmpretrain.registry import TOKENIZER +from .huggingface import register_hf_tokenizer + +register_hf_tokenizer(AutoTokenizer) +register_hf_tokenizer(LlamaTokenizer) + + +@register_hf_tokenizer() +class BlipTokenizer(BertTokenizerFast): + """"BlipTokenizer inherit BertTokenizerFast (fast, Rust-based).""" + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + os.environ['TOKENIZERS_PARALLELISM'] = 'true' + + tokenizer = super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) + return tokenizer + + +@register_hf_tokenizer() +class Blip2Tokenizer(BertTokenizer): + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + tokenizer = super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + return tokenizer + + +@register_hf_tokenizer() +class OFATokenizer(BartTokenizer): + + vocab_files_names = { + 'vocab_file': 'vocab.json', + 'merges_file': 'merges.txt' + } + + pretrained_vocab_files_map = { + 'vocab_file': { + 'OFA-Sys/OFA-tiny': + 'https://huggingface.co/OFA-Sys/OFA-tiny/blob/main/vocab.json', + 'OFA-Sys/OFA-medium': + 'https://huggingface.co/OFA-Sys/OFA-medium/blob/main/vocab.json', + 'OFA-Sys/OFA-base': + 'https://huggingface.co/OFA-Sys/OFA-base/blob/main/vocab.json', + 'OFA-Sys/OFA-large': + 'https://huggingface.co/OFA-Sys/OFA-large/blob/main/vocab.json', + }, + 'merges_file': { + 'OFA-Sys/OFA-tiny': + 'https://huggingface.co/OFA-Sys/OFA-tiny/blob/main/merges.txt', + 'OFA-Sys/OFA-medium': + 'https://huggingface.co/OFA-Sys/OFA-medium/blob/main/merges.txt', + 'OFA-Sys/OFA-base': + 'https://huggingface.co/OFA-Sys/OFA-base/blob/main/merges.txt', + 'OFA-Sys/OFA-large': + 'https://huggingface.co/OFA-Sys/OFA-large/blob/main/merges.txt', + }, + } + + max_model_input_sizes = { + 'OFA-Sys/OFA-tiny': 1024, + 'OFA-Sys/OFA-medium': 1024, + 'OFA-Sys/OFA-base': 1024, + 'OFA-Sys/OFA-large': 1024, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + num_bins = kwargs.pop('num_bins', 1000) + tokenizer = super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + length = len(tokenizer) + tokenizer.add_tokens([''.format(i) for i in range(8192)]) + tokenizer.code_offset = length + tokenizer.add_tokens([''.format(i) for i in range(num_bins)]) + tokenizer.bin_offset = length + 8192 + tokenizer.num_bins = num_bins + return tokenizer + + +@TOKENIZER.register_module() +class FullTokenizer(BertTokenizer): + """Runs end-to-end tokenziation.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = self.load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token='[UNK]', max_input_chars_per_word=200) + + def load_vocab(self, vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + vocab_list = list_from_file(vocab_file) + for token in vocab_list: + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_by_vocab(self, vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + def convert_tokens_to_ids(self, tokens): + return self.convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return self.convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """Converts a sequence of tokens (string) in a single string.""" + + def clean_up_tokenization(out_string): + """Clean up a list of simple English tokenization artifacts like + spaces before punctuations and abbreviated forms.""" + out_string = ( + out_string.replace(' .', '.').replace(' ?', '?').replace( + ' !', '!').replace(' ,', ',').replace(" ' ", "'").replace( + " n't", "n't").replace(" 'm", "'m").replace( + " 's", "'s").replace(" 've", + "'ve").replace(" 're", "'re")) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) diff --git a/mmpretrain/models/utils/vector_quantizer.py b/mmpretrain/models/utils/vector_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2ea89339e190d0d19bf5c89b60c1d4bab8fad5 --- /dev/null +++ b/mmpretrain/models/utils/vector_quantizer.py @@ -0,0 +1,232 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) 2022 Microsoft +# Modified from +# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from mmengine.dist import all_reduce + + +def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average with norm data.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1)) + + +def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor: + """Sample vectors according to the given number.""" + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num, ), device=device) + + return samples[indices] + + +def kmeans(samples: torch.Tensor, + num_clusters: int, + num_iters: int = 10, + use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """Run k-means algorithm.""" + dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = F.normalize(new_means, p=2, dim=-1) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + """The codebook of embedding vectors. + + Args: + num_tokens (int): Number of embedding vectors in the codebook. + codebook_dim (int) : The dimension of embedding vectors in the + codebook. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. + """ + + def __init__(self, + num_tokens: int, + codebook_dim: int, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + if codebook_init_path is None: + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = F.normalize(weight, p=2, dim=-1) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f'load init codebook weight from {codebook_init_path}') + codebook_ckpt_weight = torch.load( + codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data: torch.Tensor) -> None: + """Initialize embedding vectors of codebook.""" + if self.initted: + return + print('Performing K-means init for codebook') + embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id: torch.Tensor) -> torch.Tensor: + """Get embedding vectors.""" + return F.embedding(embed_id, self.weight) + + +class NormEMAVectorQuantizer(nn.Module): + """Normed EMA vector quantizer module. + + Args: + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + beta (float): The mutiplier for VectorQuantizer embedding loss. + Defaults to 1. + decay (float): The decay parameter of EMA. Defaults to 0.99. + statistic_code_usage (bool): Whether to use cluster_size to record + statistic. Defaults to True. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. + """ + + def __init__(self, + num_embed: int, + embed_dims: int, + beta: float, + decay: float = 0.99, + statistic_code_usage: bool = True, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None) -> None: + super().__init__() + self.codebook_dim = embed_dims + self.num_tokens = num_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA( + num_tokens=self.num_tokens, + codebook_dim=self.codebook_dim, + kmeans_init=kmeans_init, + codebook_init_path=codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(num_embed)) + + def reset_cluster_size(self, device): + + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + """Forward function.""" + # reshape z -> (batch, height, width, channel) + z = rearrange(z, 'b c h w -> b h w c') + z = F.normalize(z, p=2, dim=-1) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + # 'n d -> d n' + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + all_reduce(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # update cluster size with EMA + bins = encodings.sum(0) + all_reduce(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) + + embed_sum = z_flattened.t() @ encodings + all_reduce(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = F.normalize(embed_normalized, p=2, dim=-1) + embed_normalized = torch.where(zero_mask[..., None], + self.embedding.weight, + embed_normalized) + + # Update embedding vectors with EMA + norm_ema_inplace(self.embedding.weight, embed_normalized, + self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, encoding_indices diff --git a/mmpretrain/registry.py b/mmpretrain/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..cac2bdad725b9adf5c345d58e5e4a0320b3ddcd4 --- /dev/null +++ b/mmpretrain/registry.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMPretrain provides 21 registry nodes to support using modules across +projects. Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. +""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +__all__ = [ + 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', 'HOOKS', 'LOG_PROCESSORS', + 'OPTIMIZERS', 'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS', + 'PARAM_SCHEDULERS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', + 'MODEL_WRAPPERS', 'WEIGHT_INITIALIZERS', 'BATCH_AUGMENTS', 'TASK_UTILS', + 'METRICS', 'EVALUATORS', 'VISUALIZERS', 'VISBACKENDS' +] + +####################################################################### +# mmpretrain.engine # +####################################################################### + +# Runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry( + 'runner', + parent=MMENGINE_RUNNERS, + locations=['mmpretrain.engine'], +) +# Runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', + parent=MMENGINE_RUNNER_CONSTRUCTORS, + locations=['mmpretrain.engine'], +) +# Loops which define the training or test process, like `EpochBasedTrainLoop` +LOOPS = Registry( + 'loop', + parent=MMENGINE_LOOPS, + locations=['mmpretrain.engine'], +) +# Hooks to add additional functions during running, like `CheckpointHook` +HOOKS = Registry( + 'hook', + parent=MMENGINE_HOOKS, + locations=['mmpretrain.engine'], +) +# Log processors to process the scalar log data. +LOG_PROCESSORS = Registry( + 'log processor', + parent=MMENGINE_LOG_PROCESSORS, + locations=['mmpretrain.engine'], +) +# Optimizers to optimize the model weights, like `SGD` and `Adam`. +OPTIMIZERS = Registry( + 'optimizer', + parent=MMENGINE_OPTIMIZERS, + locations=['mmpretrain.engine'], +) +# Optimizer wrappers to enhance the optimization process. +OPTIM_WRAPPERS = Registry( + 'optimizer_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmpretrain.engine'], +) +# Optimizer constructors to customize the hyperparameters of optimizers. +OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer wrapper constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmpretrain.engine'], +) +# Parameter schedulers to dynamically adjust optimization parameters. +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmpretrain.engine'], +) + +####################################################################### +# mmpretrain.datasets # +####################################################################### + +# Datasets like `ImageNet` and `CIFAR10`. +DATASETS = Registry( + 'dataset', + parent=MMENGINE_DATASETS, + locations=['mmpretrain.datasets'], +) +# Samplers to sample the dataset. +DATA_SAMPLERS = Registry( + 'data sampler', + parent=MMENGINE_DATA_SAMPLERS, + locations=['mmpretrain.datasets'], +) +# Transforms to process the samples from the dataset. +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmpretrain.datasets'], +) + +####################################################################### +# mmpretrain.models # +####################################################################### + +# Neural network modules inheriting `nn.Module`. +MODELS = Registry( + 'model', + parent=MMENGINE_MODELS, + locations=['mmpretrain.models'], +) +# Model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmpretrain.models'], +) +# Weight initialization methods like uniform, xavier. +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmpretrain.models'], +) +# Batch augmentations like `Mixup` and `CutMix`. +BATCH_AUGMENTS = Registry( + 'batch augment', + locations=['mmpretrain.models'], +) +# Task-specific modules like anchor generators and box coders +TASK_UTILS = Registry( + 'task util', + parent=MMENGINE_TASK_UTILS, + locations=['mmpretrain.models'], +) +# Tokenizer to encode sequence +TOKENIZER = Registry( + 'tokenizer', + locations=['mmpretrain.models'], +) + +####################################################################### +# mmpretrain.evaluation # +####################################################################### + +# Metrics to evaluate the model prediction results. +METRICS = Registry( + 'metric', + parent=MMENGINE_METRICS, + locations=['mmpretrain.evaluation'], +) +# Evaluators to define the evaluation process. +EVALUATORS = Registry( + 'evaluator', + parent=MMENGINE_EVALUATOR, + locations=['mmpretrain.evaluation'], +) + +####################################################################### +# mmpretrain.visualization # +####################################################################### + +# Visualizers to display task-specific results. +VISUALIZERS = Registry( + 'visualizer', + parent=MMENGINE_VISUALIZERS, + locations=['mmpretrain.visualization'], +) +# Backends to save the visualization results, like TensorBoard, WandB. +VISBACKENDS = Registry( + 'vis_backend', + parent=MMENGINE_VISBACKENDS, + locations=['mmpretrain.visualization'], +) diff --git a/mmpretrain/structures/__init__.py b/mmpretrain/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7de863087d9d07800ff119d3c8b941059ef3886 --- /dev/null +++ b/mmpretrain/structures/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .data_sample import DataSample +from .multi_task_data_sample import MultiTaskDataSample +from .utils import (batch_label_to_onehot, cat_batch_labels, format_label, + format_score, label_to_onehot, tensor_split) + +__all__ = [ + 'DataSample', 'batch_label_to_onehot', 'cat_batch_labels', 'tensor_split', + 'MultiTaskDataSample', 'label_to_onehot', 'format_label', 'format_score' +] diff --git a/mmpretrain/structures/__pycache__/__init__.cpython-38.pyc b/mmpretrain/structures/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4344b392d945b99431b1627e3eb892547e95e39a Binary files /dev/null and b/mmpretrain/structures/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/structures/__pycache__/data_sample.cpython-38.pyc b/mmpretrain/structures/__pycache__/data_sample.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0479b384a3be56abae9230aa8d8adc69fbc06c6b Binary files /dev/null and b/mmpretrain/structures/__pycache__/data_sample.cpython-38.pyc differ diff --git a/mmpretrain/structures/__pycache__/multi_task_data_sample.cpython-38.pyc b/mmpretrain/structures/__pycache__/multi_task_data_sample.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccff557f8ad00658285db86e9f3ee2eb888f4ff7 Binary files /dev/null and b/mmpretrain/structures/__pycache__/multi_task_data_sample.cpython-38.pyc differ diff --git a/mmpretrain/structures/__pycache__/utils.cpython-38.pyc b/mmpretrain/structures/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99880a3e4c8cce6fdff1da7b12ed1787cb8f8344 Binary files /dev/null and b/mmpretrain/structures/__pycache__/utils.cpython-38.pyc differ diff --git a/mmpretrain/structures/data_sample.py b/mmpretrain/structures/data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..ce588b8ba13811afdb2bb3300d42f221a6f2df7f --- /dev/null +++ b/mmpretrain/structures/data_sample.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing.reduction import ForkingPickler +from typing import Union + +import numpy as np +import torch +from mmengine.structures import BaseDataElement + +from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score + + +class DataSample(BaseDataElement): + """A general data structure interface. + + It's used as the interface between different components. + + The following fields are convention names in MMPretrain, and we will set or + get these fields in data transforms, models, and metrics if needed. You can + also set any new fields for your need. + + Meta fields: + img_shape (Tuple): The shape of the corresponding input image. + ori_shape (Tuple): The original shape of the corresponding image. + sample_idx (int): The index of the sample in the dataset. + num_classes (int): The number of all categories. + + Data fields: + gt_label (tensor): The ground truth label. + gt_score (tensor): The ground truth score. + pred_label (tensor): The predicted label. + pred_score (tensor): The predicted score. + mask (tensor): The mask used in masked image modeling. + + Examples: + >>> import torch + >>> from mmpretrain.structures import DataSample + >>> + >>> img_meta = dict(img_shape=(960, 720), num_classes=5) + >>> data_sample = DataSample(metainfo=img_meta) + >>> data_sample.set_gt_label(3) + >>> print(data_sample) + + >>> + >>> # For multi-label data + >>> data_sample = DataSample().set_gt_label([0, 1, 4]) + >>> print(data_sample) + + >>> + >>> # Set one-hot format score + >>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1]) + >>> print(data_sample) + + >>> + >>> # Set custom field + >>> data_sample = DataSample() + >>> data_sample.my_field = [1, 2, 3] + >>> print(data_sample) + + >>> print(data_sample.my_field) + [1, 2, 3] + """ + + def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample': + """Set ``gt_label``.""" + self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor) + return self + + def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample': + """Set ``gt_score``.""" + score = format_score(value) + self.set_field(score, 'gt_score', dtype=torch.Tensor) + if hasattr(self, 'num_classes'): + assert len(score) == self.num_classes, \ + f'The length of score {len(score)} should be '\ + f'equal to the num_classes {self.num_classes}.' + else: + self.set_field( + name='num_classes', value=len(score), field_type='metainfo') + return self + + def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample': + """Set ``pred_label``.""" + self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor) + return self + + def set_pred_score(self, value: SCORE_TYPE): + """Set ``pred_label``.""" + score = format_score(value) + self.set_field(score, 'pred_score', dtype=torch.Tensor) + if hasattr(self, 'num_classes'): + assert len(score) == self.num_classes, \ + f'The length of score {len(score)} should be '\ + f'equal to the num_classes {self.num_classes}.' + else: + self.set_field( + name='num_classes', value=len(score), field_type='metainfo') + return self + + def set_mask(self, value: Union[torch.Tensor, np.ndarray]): + if isinstance(value, np.ndarray): + value = torch.from_numpy(value) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Invalid mask type {type(value)}') + self.set_field(value, 'mask', dtype=torch.Tensor) + return self + + def __repr__(self) -> str: + """Represent the object.""" + + def dump_items(items, prefix=''): + return '\n'.join(f'{prefix}{k}: {v}' for k, v in items) + + repr_ = '' + if len(self._metainfo_fields) > 0: + repr_ += '\n\nMETA INFORMATION\n' + repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4) + if len(self._data_fields) > 0: + repr_ += '\n\nDATA FIELDS\n' + repr_ += dump_items(self.items(), prefix=' ' * 4) + + repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>' + return repr_ + + +def _reduce_datasample(data_sample): + """reduce DataSample.""" + attr_dict = data_sample.__dict__ + convert_keys = [] + for k, v in attr_dict.items(): + if isinstance(v, torch.Tensor): + attr_dict[k] = v.numpy() + convert_keys.append(k) + return _rebuild_datasample, (attr_dict, convert_keys) + + +def _rebuild_datasample(attr_dict, convert_keys): + """rebuild DataSample.""" + data_sample = DataSample() + for k in convert_keys: + attr_dict[k] = torch.from_numpy(attr_dict[k]) + data_sample.__dict__ = attr_dict + return data_sample + + +# Due to the multi-processing strategy of PyTorch, DataSample may consume many +# file descriptors because it contains multiple tensors. Here we overwrite the +# reduce function of DataSample in ForkingPickler and convert these tensors to +# np.ndarray during pickling. It may slightly influence the performance of +# dataloader. +ForkingPickler.register(DataSample, _reduce_datasample) diff --git a/mmpretrain/structures/multi_task_data_sample.py b/mmpretrain/structures/multi_task_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f00993861bfb4f35fb7d145198f81c5e9f0a5993 --- /dev/null +++ b/mmpretrain/structures/multi_task_data_sample.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmengine.structures import BaseDataElement + + +class MultiTaskDataSample(BaseDataElement): + + @property + def tasks(self): + return self._data_fields diff --git a/mmpretrain/structures/utils.py b/mmpretrain/structures/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f9e95ef6fd557b9d0bdf5f017a7b73ba250453 --- /dev/null +++ b/mmpretrain/structures/utils.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.utils import is_str + +if hasattr(torch, 'tensor_split'): + tensor_split = torch.tensor_split +else: + # A simple implementation of `tensor_split`. + def tensor_split(input: torch.Tensor, indices: list): + outs = [] + for start, end in zip([0] + indices, indices + [input.size(0)]): + outs.append(input[start:end]) + return outs + + +LABEL_TYPE = Union[torch.Tensor, np.ndarray, Sequence, int] +SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence] + + +def format_label(value: LABEL_TYPE) -> torch.Tensor: + """Convert various python types to label-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence | int): Label value. + + Returns: + :obj:`torch.Tensor`: The foramtted label tensor. + """ + + # Handle single number + if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0: + value = int(value.item()) + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(torch.long) + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).to(torch.long) + elif isinstance(value, int): + value = torch.LongTensor([value]) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +def format_score(value: SCORE_TYPE) -> torch.Tensor: + """Convert various python types to score-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence): Score values. + + Returns: + :obj:`torch.Tensor`: The foramtted score tensor. + """ + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).float() + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).float() + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +def cat_batch_labels(elements: List[torch.Tensor]): + """Concat a batch of label tensor to one tensor. + + Args: + elements (List[tensor]): A batch of labels. + + Returns: + Tuple[torch.Tensor, List[int]]: The first item is the concated label + tensor, and the second item is the split indices of every sample. + """ + labels = [] + splits = [0] + for element in elements: + labels.append(element) + splits.append(splits[-1] + element.size(0)) + batch_label = torch.cat(labels) + return batch_label, splits[1:-1] + + +def batch_label_to_onehot(batch_label, split_indices, num_classes): + """Convert a concated label tensor to onehot format. + + Args: + batch_label (torch.Tensor): A concated label tensor from multiple + samples. + split_indices (List[int]): The split indices of every sample. + num_classes (int): The number of classes. + + Returns: + torch.Tensor: The onehot format label tensor. + + Examples: + >>> import torch + >>> from mmpretrain.structures import batch_label_to_onehot + >>> # Assume a concated label from 3 samples. + >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1] + >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1]) + >>> split_indices = [2, 5] + >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5) + tensor([[1, 1, 0, 0, 0], + [1, 0, 1, 0, 1], + [0, 1, 0, 1, 0]]) + """ + sparse_onehot_list = F.one_hot(batch_label, num_classes) + onehot_list = [ + sparse_onehot.sum(0) + for sparse_onehot in tensor_split(sparse_onehot_list, split_indices) + ] + return torch.stack(onehot_list) + + +def label_to_onehot(label: LABEL_TYPE, num_classes: int): + """Convert a label to onehot format tensor. + + Args: + label (LABEL_TYPE): Label value. + num_classes (int): The number of classes. + + Returns: + torch.Tensor: The onehot format label tensor. + + Examples: + >>> import torch + >>> from mmpretrain.structures import label_to_onehot + >>> # Single-label + >>> label_to_onehot(1, num_classes=5) + tensor([0, 1, 0, 0, 0]) + >>> # Multi-label + >>> label_to_onehot([0, 2, 3], num_classes=5) + tensor([1, 0, 1, 1, 0]) + """ + label = format_label(label) + sparse_onehot = F.one_hot(label, num_classes) + return sparse_onehot.sum(0) diff --git a/mmpretrain/utils/__init__.py b/mmpretrain/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..991e3217d2f1e5926028e6c9c79e450e30404a33 --- /dev/null +++ b/mmpretrain/utils/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .analyze import load_json_log +from .collect_env import collect_env +from .dependency import require +from .misc import get_ori_model +from .progress import track, track_on_main_process +from .setup_env import register_all_modules + +__all__ = [ + 'collect_env', 'register_all_modules', 'track_on_main_process', + 'load_json_log', 'get_ori_model', 'track', 'require' +] diff --git a/mmpretrain/utils/__pycache__/__init__.cpython-38.pyc b/mmpretrain/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8a127b231c6ab9b266c43bb72671c36c639f59b Binary files /dev/null and b/mmpretrain/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/mmpretrain/utils/__pycache__/analyze.cpython-38.pyc b/mmpretrain/utils/__pycache__/analyze.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee4db9f3da760a7e8c33c21c346b40f7ddbe2d42 Binary files /dev/null and b/mmpretrain/utils/__pycache__/analyze.cpython-38.pyc differ diff --git a/mmpretrain/utils/__pycache__/collect_env.cpython-38.pyc b/mmpretrain/utils/__pycache__/collect_env.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..876d17463525b4dc4251a58d9090e0d5e7a89237 Binary files /dev/null and b/mmpretrain/utils/__pycache__/collect_env.cpython-38.pyc differ diff --git a/mmpretrain/utils/__pycache__/dependency.cpython-38.pyc b/mmpretrain/utils/__pycache__/dependency.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ec4d06396b44dc2b3e95696d4da31164819b614 Binary files /dev/null and b/mmpretrain/utils/__pycache__/dependency.cpython-38.pyc differ diff --git a/mmpretrain/utils/__pycache__/misc.cpython-38.pyc b/mmpretrain/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f498768f3d5133c8d04325c9f43098d12a607f25 Binary files /dev/null and b/mmpretrain/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/mmpretrain/utils/__pycache__/progress.cpython-38.pyc b/mmpretrain/utils/__pycache__/progress.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76bd48fd98bdc0c6af206d2816a50e67e8b131c3 Binary files /dev/null and b/mmpretrain/utils/__pycache__/progress.cpython-38.pyc differ diff --git a/mmpretrain/utils/__pycache__/setup_env.cpython-38.pyc b/mmpretrain/utils/__pycache__/setup_env.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9435dfd9a37bfe4867975ee4cc57fd0c02486125 Binary files /dev/null and b/mmpretrain/utils/__pycache__/setup_env.cpython-38.pyc differ diff --git a/mmpretrain/utils/analyze.py b/mmpretrain/utils/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..a933591618951e1e49558f4f5cbbdf9c49a76bfe --- /dev/null +++ b/mmpretrain/utils/analyze.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json + + +def load_json_log(json_log): + """load and convert json_logs to log_dicts. + + Args: + json_log (str): The path of the json log file. + + Returns: + dict: The result dict contains two items, "train" and "val", for + the training log and validate log. + + Example: + An example output: + + .. code-block:: python + + { + 'train': [ + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 100}, + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 200}, + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 300}, + ... + ] + 'val': [ + {"accuracy/top1": 32.1, "step": 1}, + {"accuracy/top1": 50.2, "step": 2}, + {"accuracy/top1": 60.3, "step": 2}, + ... + ] + } + """ + log_dict = dict(train=[], val=[]) + with open(json_log, 'r') as log_file: + for line in log_file: + log = json.loads(line.strip()) + # A hack trick to determine whether the line is training log. + mode = 'train' if 'lr' in log else 'val' + log_dict[mode].append(log) + + return log_dict diff --git a/mmpretrain/utils/collect_env.py b/mmpretrain/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..988451ec530e8d21ec3d5a087a3bb7f7b66fd223 --- /dev/null +++ b/mmpretrain/utils/collect_env.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + +import mmpretrain + + +def collect_env(with_torch_comiling_info=False): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMCV'] = mmcv.__version__ + if not with_torch_comiling_info: + env_info.pop('PyTorch compiling details') + env_info['MMPreTrain'] = mmpretrain.__version__ + '+' + get_git_hash()[:7] + return env_info diff --git a/mmpretrain/utils/dependency.py b/mmpretrain/utils/dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3d8ae5df7a6968f26e0563e80a7d37a2e2cd68 --- /dev/null +++ b/mmpretrain/utils/dependency.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from functools import wraps +from inspect import isfunction + +from importlib_metadata import PackageNotFoundError, distribution +from mmengine.utils import digit_version + + +def satisfy_requirement(dep): + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, dep, maxsplit=1) + parts = [p.strip() for p in parts] + package = parts[0] + if len(parts) > 1: + op, version = parts[1:] + op = { + '>=': '__ge__', + '==': '__eq__', + '>': '__gt__', + '<': '__lt__', + '<=': '__le__' + }[op] + else: + op, version = None, None + + try: + dist = distribution(package) + if op is None or getattr(digit_version(dist.version), op)( + digit_version(version)): + return True + except PackageNotFoundError: + pass + + return False + + +def require(dep, install=None): + """A wrapper of function for extra package requirements. + + Args: + dep (str): The dependency package name, like ``transformers`` + or ``transformers>=4.28.0``. + install (str, optional): The installation command hint. Defaults + to None, which means to use "pip install dep". + """ + + def wrapper(fn): + assert isfunction(fn) + + @wraps(fn) + def ask_install(*args, **kwargs): + name = fn.__qualname__.replace('.__init__', '') + ins = install or f'pip install "{dep}"' + raise ImportError( + f'{name} requires {dep}, please install it by `{ins}`.') + + if satisfy_requirement(dep): + fn._verify_require = getattr(fn, '_verify_require', lambda: None) + return fn + + ask_install._verify_require = ask_install + return ask_install + + return wrapper + + +WITH_MULTIMODAL = all( + satisfy_requirement(item) + for item in ['pycocotools', 'transformers>=4.28.0']) + + +def register_multimodal_placeholder(names, registry): + for name in names: + + def ask_install(*args, **kwargs): + raise ImportError( + f'{name} requires extra multi-modal dependencies, please ' + 'install it by `pip install "mmpretrain[multimodal]"` ' + 'or `pip install -e ".[multimodal]"`.') + + registry.register_module(name=name, module=ask_install) diff --git a/mmpretrain/utils/misc.py b/mmpretrain/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..cc532679943689233be76e9a8f74da8ed822443e --- /dev/null +++ b/mmpretrain/utils/misc.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmengine.model import is_model_wrapper + + +def get_ori_model(model: nn.Module) -> nn.Module: + """Get original model if the input model is a model wrapper. + + Args: + model (nn.Module): A model may be a model wrapper. + + Returns: + nn.Module: The model without model wrapper. + """ + if is_model_wrapper(model): + return model.module + else: + return model diff --git a/mmpretrain/utils/progress.py b/mmpretrain/utils/progress.py new file mode 100644 index 0000000000000000000000000000000000000000..b23f976a42fc3a2f6e38f025f01041deb5608405 --- /dev/null +++ b/mmpretrain/utils/progress.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import mmengine.dist as dist +import rich.progress as progress +from rich.live import Live + +disable_progress_bar = False +global_progress = progress.Progress( + '{task.description}', + progress.BarColumn(), + progress.TaskProgressColumn(show_speed=True), + progress.TimeRemainingColumn(), +) +global_live = Live(global_progress, refresh_per_second=10) + + +def track(sequence, description: str = '', total: Optional[float] = None): + if disable_progress_bar: + yield from sequence + else: + global_live.start() + task_id = global_progress.add_task(description, total=total) + task = global_progress._tasks[task_id] + try: + yield from global_progress.track(sequence, task_id=task_id) + finally: + if task.total is None: + global_progress.update(task_id, total=task.completed) + if all(task.finished for task in global_progress.tasks): + global_live.stop() + for task_id in global_progress.task_ids: + global_progress.remove_task(task_id) + + +def track_on_main_process(sequence, description='', total=None): + if not dist.is_main_process() or disable_progress_bar: + yield from sequence + else: + yield from track(sequence, total=total, description=description) diff --git a/mmpretrain/utils/setup_env.py b/mmpretrain/utils/setup_env.py new file mode 100644 index 0000000000000000000000000000000000000000..1b57b848c98a75c7a1b5854c800ecc2dd5da6df8 --- /dev/null +++ b/mmpretrain/utils/setup_env.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import warnings + +from mmengine import DefaultScope + + +def register_all_modules(init_default_scope: bool = True) -> None: + """Register all modules in mmpretrain into the registries. + + Args: + init_default_scope (bool): Whether initialize the mmpretrain default + scope. If True, the global default scope will be set to + `mmpretrain`, and all registries will build modules from + mmpretrain's registry node. To understand more about the registry, + please refer to + https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md + Defaults to True. + """ # noqa: E501 + import mmpretrain.datasets # noqa: F401,F403 + import mmpretrain.engine # noqa: F401,F403 + import mmpretrain.evaluation # noqa: F401,F403 + import mmpretrain.models # noqa: F401,F403 + import mmpretrain.structures # noqa: F401,F403 + import mmpretrain.visualization # noqa: F401,F403 + + if not init_default_scope: + return + + current_scope = DefaultScope.get_current_instance() + if current_scope is None: + DefaultScope.get_instance('mmpretrain', scope_name='mmpretrain') + elif current_scope.scope_name != 'mmpretrain': + warnings.warn( + f'The current default scope "{current_scope.scope_name}" ' + 'is not "mmpretrain", `register_all_modules` will force ' + 'the current default scope to be "mmpretrain". If this is ' + 'not expected, please set `init_default_scope=False`.') + # avoid name conflict + new_instance_name = f'mmpretrain-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name='mmpretrain') diff --git a/mmpretrain/version.py b/mmpretrain/version.py new file mode 100644 index 0000000000000000000000000000000000000000..6a60b40f31da1d6681d70010af3556b3b2363e5d --- /dev/null +++ b/mmpretrain/version.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved + +__version__ = '1.0.0' + + +def parse_version_info(version_str): + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). + """ + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] diff --git a/mmpretrain/visualization/__init__.py b/mmpretrain/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbeecfb070193f479b248dca3e98311577410a1 --- /dev/null +++ b/mmpretrain/visualization/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .utils import create_figure, get_adaptive_scale +from .visualizer import UniversalVisualizer + +__all__ = ['UniversalVisualizer', 'get_adaptive_scale', 'create_figure'] diff --git a/mmpretrain/visualization/utils.py b/mmpretrain/visualization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..91a1d81f1449dfbfb7ff5198eb6dc25a6386ed48 --- /dev/null +++ b/mmpretrain/visualization/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import TYPE_CHECKING, Tuple + +if TYPE_CHECKING: + from matplotlib.figure import Figure + + +def get_adaptive_scale(img_shape: Tuple[int, int], + min_scale: float = 0.3, + max_scale: float = 3.0) -> float: + """Get adaptive scale according to image shape. + + The target scale depends on the the short edge length of the image. If the + short edge length equals 224, the output is 1.0. And output linear scales + according the short edge length. + + You can also specify the minimum scale and the maximum scale to limit the + linear scale. + + Args: + img_shape (Tuple[int, int]): The shape of the canvas image. + min_size (int): The minimum scale. Defaults to 0.3. + max_size (int): The maximum scale. Defaults to 3.0. + + Returns: + int: The adaptive scale. + """ + short_edge_length = min(img_shape) + scale = short_edge_length / 224. + return min(max(scale, min_scale), max_scale) + + +def create_figure(*args, margin=False, **kwargs) -> 'Figure': + """Create a independent figure. + + Different from the :func:`plt.figure`, the figure from this function won't + be managed by matplotlib. And it has + :obj:`matplotlib.backends.backend_agg.FigureCanvasAgg`, and therefore, you + can use the ``canvas`` attribute to get access the drawn image. + + Args: + *args: All positional arguments of :class:`matplotlib.figure.Figure`. + margin: Whether to reserve the white edges of the figure. + Defaults to False. + **kwargs: All keyword arguments of :class:`matplotlib.figure.Figure`. + + Return: + matplotlib.figure.Figure: The created figure. + """ + from matplotlib.backends.backend_agg import FigureCanvasAgg + from matplotlib.figure import Figure + + figure = Figure(*args, **kwargs) + FigureCanvasAgg(figure) + + if not margin: + # remove white edges by set subplot margin + figure.subplots_adjust(left=0, right=1, bottom=0, top=1) + + return figure diff --git a/mmpretrain/visualization/visualizer.py b/mmpretrain/visualization/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d18ca87f6bc246b4defe17281ae87c4464e1b89 --- /dev/null +++ b/mmpretrain/visualization/visualizer.py @@ -0,0 +1,777 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.dataset import BaseDataset +from mmengine.dist import master_only +from mmengine.visualization import Visualizer +from mmengine.visualization.utils import img_from_canvas + +from mmpretrain.registry import VISUALIZERS +from mmpretrain.structures import DataSample +from .utils import create_figure, get_adaptive_scale + + +@VISUALIZERS.register_module() +class UniversalVisualizer(Visualizer): + """Universal Visualizer for multiple tasks. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + fig_save_cfg (dict): Keyword parameters of figure for saving. + Defaults to empty dict. + fig_show_cfg (dict): Keyword parameters of figure for showing. + Defaults to empty dict. + """ + DEFAULT_TEXT_CFG = { + 'family': 'monospace', + 'color': 'white', + 'bbox': dict(facecolor='black', alpha=0.5, boxstyle='Round'), + 'verticalalignment': 'top', + 'horizontalalignment': 'left', + } + + @master_only + def visualize_cls(self, + image: np.ndarray, + data_sample: DataSample, + classes: Optional[Sequence[str]] = None, + draw_gt: bool = True, + draw_pred: bool = True, + draw_score: bool = True, + resize: Optional[int] = None, + rescale_factor: Optional[float] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize image classification result. + + This method will draw an text box on the input image to visualize the + information about image classification, like the ground-truth label and + prediction label. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + classes (Sequence[str], optional): The categories names. + Defaults to None. + draw_gt (bool): Whether to draw ground-truth labels. + Defaults to True. + draw_pred (bool): Whether to draw prediction labels. + Defaults to True. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + rescale_factor (float, optional): Rescale the image by the rescale + factor before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts + arguments of :meth:`mmengine.Visualizer.draw_texts`. + Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + if self.dataset_meta is not None: + classes = classes or self.dataset_meta.get('classes', None) + + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h // w)) + else: + image = mmcv.imresize(image, (resize * w // h, resize)) + elif rescale_factor is not None: + image = mmcv.imrescale(image, rescale_factor) + + texts = [] + self.set_image(image) + + if draw_gt and 'gt_label' in data_sample: + idx = data_sample.gt_label.tolist() + class_labels = [''] * len(idx) + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))] + prefix = 'Ground truth: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + if draw_pred and 'pred_label' in data_sample: + idx = data_sample.pred_label.tolist() + score_labels = [''] * len(idx) + class_labels = [''] * len(idx) + if draw_score and 'pred_score' in data_sample: + score_labels = [ + f', {data_sample.pred_score[i].item():.2f}' for i in idx + ] + + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + + labels = [ + str(idx[i]) + score_labels[i] + class_labels[i] + for i in range(len(idx)) + ] + prefix = 'Prediction: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + self.ax_save.text( + img_scale * 5, + img_scale * 5, + '\n'.join(texts), + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_image_retrieval(self, + image: np.ndarray, + data_sample: DataSample, + prototype_dataset: BaseDataset, + topk: int = 1, + draw_score: bool = True, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + step: int = 0) -> None: + """Visualize image retrieval result. + + This method will draw the input image and the images retrieved from the + prototype dataset. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + prototype_dataset (:obj:`BaseDataset`): The prototype dataset. + It should have `get_data_info` method and return a dict + includes `img_path`. + draw_score (bool): Whether to draw the match scores of the + retrieved images. Defaults to True. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + if resize is not None: + image = mmcv.imrescale(image, (resize, resize)) + + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + + figure = create_figure(margin=True) + gs = figure.add_gridspec(2, topk) + query_plot = figure.add_subplot(gs[0, :]) + query_plot.axis(False) + query_plot.imshow(image) + + for k, (score, sample_idx) in enumerate(zip(match_scores, indices)): + sample = prototype_dataset.get_data_info(sample_idx.item()) + value_image = mmcv.imread(sample['img_path'])[..., ::-1] + value_plot = figure.add_subplot(gs[1, k]) + value_plot.axis(False) + value_plot.imshow(value_image) + if draw_score: + value_plot.text( + 5, + 5, + f'{score:.2f}', + **text_cfg, + ) + drawn_img = img_from_canvas(figure.canvas) + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + def add_mask_to_image( + self, + image: np.ndarray, + data_sample: DataSample, + resize: Union[int, Tuple[int]] = 224, + color: Union[str, Tuple[int]] = 'black', + alpha: Union[int, float] = 0.8, + ) -> np.ndarray: + if isinstance(resize, int): + resize = (resize, resize) + + image = mmcv.imresize(image, resize) + self.set_image(image) + + if isinstance(data_sample.mask, np.ndarray): + data_sample.mask = torch.tensor(data_sample.mask) + mask = data_sample.mask.float()[None, None, ...] + mask_ = F.interpolate(mask, image.shape[:2], mode='nearest')[0, 0] + + self.draw_binary_masks(mask_.bool(), colors=color, alphas=alpha) + + drawn_img = self.get_image() + return drawn_img + + @master_only + def visualize_masked_image(self, + image: np.ndarray, + data_sample: DataSample, + resize: Union[int, Tuple[int]] = 224, + color: Union[str, Tuple[int]] = 'black', + alpha: Union[int, float] = 0.8, + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize masked image. + + This method will draw an image with binary mask. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + resize (int | Tuple[int]): Resize the input image to the specified + shape. Defaults to 224. + color (str | Tuple[int]): The color of the binary mask. + Defaults to "black". + alpha (int | float): The transparency of the mask. Defaults to 0.8. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + drawn_img = self.add_mask_to_image( + image=image, + data_sample=data_sample, + resize=resize, + color=color, + alpha=alpha) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_image_caption(self, + image: np.ndarray, + data_sample: DataSample, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + step: int = 0) -> None: + """Visualize image caption result. + + This method will draw the input image and the images caption. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h // w)) + else: + image = mmcv.imresize(image, (resize * w // h, resize)) + + self.set_image(image) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + self.ax_save.text( + img_scale * 5, + img_scale * 5, + data_sample.get('pred_caption'), + wrap=True, + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_vqa(self, + image: np.ndarray, + data_sample: DataSample, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + step: int = 0) -> None: + """Visualize visual question answering result. + + This method will draw the input image, question and answer. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h // w)) + else: + image = mmcv.imresize(image, (resize * w // h, resize)) + + self.set_image(image) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + text = (f'Q: {data_sample.get("question")}\n' + f'A: {data_sample.get("pred_answer")}') + self.ax_save.text( + img_scale * 5, + img_scale * 5, + text, + wrap=True, + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_visual_grounding(self, + image: np.ndarray, + data_sample: DataSample, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + line_width: Union[int, float] = 3, + bbox_color: Union[str, tuple] = 'green', + step: int = 0) -> None: + """Visualize visual grounding result. + + This method will draw the input image, bbox and the object. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + + gt_bboxes = data_sample.get('gt_bboxes') + pred_bboxes = data_sample.get('pred_bboxes') + if resize is not None: + h, w = image.shape[:2] + if w < h: + image, w_scale, h_scale = mmcv.imresize( + image, (resize, resize * h // w), return_scale=True) + else: + image, w_scale, h_scale = mmcv.imresize( + image, (resize * w // h, resize), return_scale=True) + pred_bboxes[:, ::2] *= w_scale + pred_bboxes[:, 1::2] *= h_scale + if gt_bboxes is not None: + gt_bboxes[:, ::2] *= w_scale + gt_bboxes[:, 1::2] *= h_scale + + self.set_image(image) + # Avoid the line-width limit in the base classes. + self._default_font_size = 1e3 + self.draw_bboxes( + pred_bboxes, line_widths=line_width, edge_colors=bbox_color) + if gt_bboxes is not None: + self.draw_bboxes( + gt_bboxes, line_widths=line_width, edge_colors='blue') + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + + text_positions = pred_bboxes[:, :2] + line_width + for i in range(pred_bboxes.size(0)): + self.ax_save.text( + text_positions[i, 0] + line_width, + text_positions[i, 1] + line_width, + data_sample.get('text'), + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_t2i_retrieval(self, + text: str, + data_sample: DataSample, + prototype_dataset: BaseDataset, + topk: int = 1, + draw_score: bool = True, + text_cfg: dict = dict(), + fig_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + step: int = 0) -> None: + """Visualize Text-To-Image retrieval result. + + This method will draw the input text and the images retrieved from the + prototype dataset. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + prototype_dataset (:obj:`BaseDataset`): The prototype dataset. + It should have `get_data_info` method and return a dict + includes `img_path`. + topk (int): To visualize the topk matching items. Defaults to 1. + draw_score (bool): Whether to draw the match scores of the + retrieved images. Defaults to True. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + fig_cfg (dict): Extra figure setting, which accepts arguments of + :func:`plt.Figure`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + + figure = create_figure(margin=True, **fig_cfg) + figure.suptitle(text) + gs = figure.add_gridspec(1, topk) + + for k, (score, sample_idx) in enumerate(zip(match_scores, indices)): + sample = prototype_dataset.get_data_info(sample_idx.item()) + value_image = mmcv.imread(sample['img_path'])[..., ::-1] + value_plot = figure.add_subplot(gs[0, k]) + value_plot.axis(False) + value_plot.imshow(value_image) + if draw_score: + value_plot.text( + 5, + 5, + f'{score:.2f}', + **text_cfg, + ) + drawn_img = img_from_canvas(figure.canvas) + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_i2t_retrieval(self, + image: np.ndarray, + data_sample: DataSample, + prototype_dataset: Sequence[str], + topk: int = 1, + draw_score: bool = True, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize Image-To-Text retrieval result. + + This method will draw the input image and the texts retrieved from the + prototype dataset. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + prototype_dataset (Sequence[str]): The prototype dataset. + It should be a list of texts. + topk (int): To visualize the topk matching items. Defaults to 1. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts + arguments of :meth:`mmengine.Visualizer.draw_texts`. + Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h // w)) + else: + image = mmcv.imresize(image, (resize * w // h, resize)) + + self.set_image(image) + + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + texts = [] + for score, sample_idx in zip(match_scores, indices): + text = prototype_dataset[sample_idx.item()] + if draw_score: + text = f'{score:.2f} ' + text + texts.append(text) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + self.ax_save.text( + img_scale * 5, + img_scale * 5, + '\n'.join(texts), + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img diff --git a/models/__pycache__/embedder.cpython-38.pyc b/models/__pycache__/embedder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf2ff78848511322700822a6951baf73a0daf7a3 Binary files /dev/null and b/models/__pycache__/embedder.cpython-38.pyc differ diff --git a/models/__pycache__/model_utils.cpython-38.pyc b/models/__pycache__/model_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2c7fa8f21f28be981b183235415aa4be77732d8 Binary files /dev/null and b/models/__pycache__/model_utils.cpython-38.pyc differ diff --git a/models/cldm_v15.yaml b/models/cldm_v15.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fde1825577acd46dc90d8d7c6730e22be762fccb --- /dev/null +++ b/models/cldm_v15.yaml @@ -0,0 +1,79 @@ +model: + target: cldm.cldm.ControlLDM + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + control_key: "hint" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + only_mid_control: False + + control_stage_config: + target: cldm.cldm.ControlNet + params: + image_size: 32 # unused + in_channels: 4 + hint_channels: 3 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + unet_config: + target: cldm.cldm.ControlledUnetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/models/cldm_v21.yaml b/models/cldm_v21.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc65193647e476e108fce5977f11250d55919106 --- /dev/null +++ b/models/cldm_v21.yaml @@ -0,0 +1,85 @@ +model: + target: cldm.cldm.ControlLDM + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + control_key: "hint" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + only_mid_control: False + + control_stage_config: + target: cldm.cldm.ControlNet + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + hint_channels: 3 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + unet_config: + target: cldm.cldm.ControlledUnetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/models/control_sd15_openpose.pth b/models/control_sd15_openpose.pth new file mode 100644 index 0000000000000000000000000000000000000000..8bece72dd7e4ab7308e5c931b305e3895397b9d2 --- /dev/null +++ b/models/control_sd15_openpose.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d19ffffeeaff6d9feb2204b234c3e1b9aec039ab3e63fca07f4fe5646f2ef591 +size 5710751843 diff --git a/models/embedder.py b/models/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..65e7498d0172d50205cbc1e8192375c6e83f5b38 --- /dev/null +++ b/models/embedder.py @@ -0,0 +1,50 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Optional, Tuple, Union +from .model_utils import zero_module + +class Embedder(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 64, 128), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding \ No newline at end of file diff --git a/models/model_utils.py b/models/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d07f1787f29f5ff08ab42743000ae9bb6226e76 --- /dev/null +++ b/models/model_utils.py @@ -0,0 +1,8 @@ +import torch +from torch import nn +from torch.nn import functional as F + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module \ No newline at end of file diff --git a/online.log b/online.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/openclip/open_clip/__init__.py b/openclip/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..088c86441ec71a241320de79b7b66a6afeb3a049 --- /dev/null +++ b/openclip/open_clip/__init__.py @@ -0,0 +1,13 @@ +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg diff --git a/openclip/open_clip/__pycache__/__init__.cpython-38.pyc b/openclip/open_clip/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5790b7198def36087c6b79a7607c0ea81d865fce Binary files /dev/null and b/openclip/open_clip/__pycache__/__init__.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/coca_model.cpython-38.pyc b/openclip/open_clip/__pycache__/coca_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c1b7bd40607566154c30651233ab9360b4e79a0 Binary files /dev/null and b/openclip/open_clip/__pycache__/coca_model.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/constants.cpython-38.pyc b/openclip/open_clip/__pycache__/constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..280f6d545b1814c115a18619b8d6dd6046d23539 Binary files /dev/null and b/openclip/open_clip/__pycache__/constants.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/factory.cpython-38.pyc b/openclip/open_clip/__pycache__/factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f0d3d28eddf85db7eea08ea6f63aa9cc803a0a8 Binary files /dev/null and b/openclip/open_clip/__pycache__/factory.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/hf_configs.cpython-38.pyc b/openclip/open_clip/__pycache__/hf_configs.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9578616159edf40ddc69465d885122e4b27b1d19 Binary files /dev/null and b/openclip/open_clip/__pycache__/hf_configs.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/hf_model.cpython-38.pyc b/openclip/open_clip/__pycache__/hf_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4013e79c2d3a998597e88386a16d662a2064d2d3 Binary files /dev/null and b/openclip/open_clip/__pycache__/hf_model.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/loss.cpython-38.pyc b/openclip/open_clip/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59e36591370482705fff247da745784461edf600 Binary files /dev/null and b/openclip/open_clip/__pycache__/loss.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/model.cpython-38.pyc b/openclip/open_clip/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c3f1480424fea10d62cb2a5127b45d4a2f5c59a Binary files /dev/null and b/openclip/open_clip/__pycache__/model.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/modified_resnet.cpython-38.pyc b/openclip/open_clip/__pycache__/modified_resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f593f890ddd951598c26615132bc1da0b404c1d Binary files /dev/null and b/openclip/open_clip/__pycache__/modified_resnet.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/openai.cpython-38.pyc b/openclip/open_clip/__pycache__/openai.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c27df13634802d2a7451432152f8a2327c086352 Binary files /dev/null and b/openclip/open_clip/__pycache__/openai.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/pretrained.cpython-38.pyc b/openclip/open_clip/__pycache__/pretrained.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6b62b6d304b17a860ce4413ff51d2fe8d4268f5 Binary files /dev/null and b/openclip/open_clip/__pycache__/pretrained.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc b/openclip/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bd90f6ef13773a5740a9398ba14a884818c2d25 Binary files /dev/null and b/openclip/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/timm_model.cpython-38.pyc b/openclip/open_clip/__pycache__/timm_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e9979034f9e7b6352f2fb66933b4b94668421e9 Binary files /dev/null and b/openclip/open_clip/__pycache__/timm_model.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/tokenizer.cpython-38.pyc b/openclip/open_clip/__pycache__/tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7b4ee01579a0814c57ae9476ec63e1279d0d838 Binary files /dev/null and b/openclip/open_clip/__pycache__/tokenizer.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/transform.cpython-38.pyc b/openclip/open_clip/__pycache__/transform.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43141ceff3adc1c722f4088b9992a22d1bf27228 Binary files /dev/null and b/openclip/open_clip/__pycache__/transform.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/transformer.cpython-38.pyc b/openclip/open_clip/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd04ea9e1190c9ea755759e9d4523a46f59338af Binary files /dev/null and b/openclip/open_clip/__pycache__/transformer.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/utils.cpython-38.pyc b/openclip/open_clip/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d536f6fa8572238bcf080568c7d3a3197194710a Binary files /dev/null and b/openclip/open_clip/__pycache__/utils.cpython-38.pyc differ diff --git a/openclip/open_clip/__pycache__/version.cpython-38.pyc b/openclip/open_clip/__pycache__/version.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc227e8ffb080957a13c01cb36aee5f1a114e91f Binary files /dev/null and b/openclip/open_clip/__pycache__/version.cpython-38.pyc differ diff --git a/openclip/open_clip/bpe_simple_vocab_16e6.txt.gz b/openclip/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/openclip/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/openclip/open_clip/coca_model.py b/openclip/open_clip/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2 --- /dev/null +++ b/openclip/open_clip/coca_model.py @@ -0,0 +1,458 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize=True, embed_cls=True): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True, embed_cls=True): + text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + return text_latent + + def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + logits = self.text_decoder(image_embs, token_embs) + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." + assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs = image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + return torch.cat( + (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + cur_len = text.shape[1] + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + embed_cls=False, + image_latent=image_latent, + image_embs=image_embs + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/openclip/open_clip/constants.py b/openclip/open_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/openclip/open_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/openclip/open_clip/factory.py b/openclip/open_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..14011f9340fc6e54876c3c5bcb9e23a8cd57849d --- /dev/null +++ b/openclip/open_clip/factory.py @@ -0,0 +1,366 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf +from .transform import image_transform, AugmentationCfg +from .tokenizer import HFTokenizer, tokenize + + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + if model_name.startswith(HF_HUB_PREFIX): + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + else: + config = get_model_config(model_name) + tokenizer = HFTokenizer( + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, +): + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + pretrained_cfg = config['preprocess_cfg'] + model_cfg = config['model_cfg'] + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + pretrained_cfg = {} + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + device=device, + jit=jit, + cache_dir=cache_dir, + ) + + # to always output dict even if it is clip + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + if pretrained_image: + if 'timm_model_name' in model_cfg.get('vision_cfg', {}): + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + if custom_text: + if is_hf_model: + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + model.to(device=device) + if precision in ("fp16", "bf16"): + convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + # to always output dict even if it is clip + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + return model + + +def create_loss(args): + if args.distill: + return DistillClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + cache_dir=cache_dir, + require_pretrained=True, + ) + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess diff --git a/openclip/open_clip/generation_utils.py b/openclip/open_clip/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/openclip/open_clip/hf_configs.py b/openclip/open_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..e236222bafce0358445ea16953ca0b2d5a84758a --- /dev/null +++ b/openclip/open_clip/hf_configs.py @@ -0,0 +1,45 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, +} diff --git a/openclip/open_clip/hf_model.py b/openclip/open_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fbccc812757bf10b122ff14096980e0e38d1d221 --- /dev/null +++ b/openclip/open_clip/hf_model.py @@ -0,0 +1,176 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" + +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r'(? torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss diff --git a/openclip/open_clip/model.py b/openclip/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4f5e775a74641642ea0451a0d6142ccb2c6594eb --- /dev/null +++ b/openclip/open_clip/model.py @@ -0,0 +1,445 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +from dataclasses import dataclass +import logging +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from .hf_model import HFTextEncoder +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. # head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + output_tokens: bool = False + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/openclip/open_clip/model_configs/RN101-quickgelu.json b/openclip/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/openclip/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/RN101.json b/openclip/open_clip/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/openclip/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/RN50-quickgelu.json b/openclip/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/openclip/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/openclip/open_clip/model_configs/RN50.json b/openclip/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/openclip/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/RN50x16.json b/openclip/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/openclip/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/RN50x4.json b/openclip/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/openclip/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/RN50x64.json b/openclip/open_clip/model_configs/RN50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7 --- /dev/null +++ b/openclip/open_clip/model_configs/RN50x64.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-B-16-plus-240.json b/openclip/open_clip/model_configs/ViT-B-16-plus-240.json new file mode 100644 index 0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-B-16-plus.json b/openclip/open_clip/model_configs/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-B-16.json b/openclip/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-B-32-plus-256.json b/openclip/open_clip/model_configs/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-B-32-quickgelu.json b/openclip/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-B-32.json b/openclip/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-H-14.json b/openclip/open_clip/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-H-16.json b/openclip/open_clip/model_configs/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-L-14-280.json b/openclip/open_clip/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-L-14-336.json b/openclip/open_clip/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-L-14.json b/openclip/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-L-16-320.json b/openclip/open_clip/model_configs/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-L-16.json b/openclip/open_clip/model_configs/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-M-16-alt.json b/openclip/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-M-16.json b/openclip/open_clip/model_configs/ViT-M-16.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-M-32-alt.json b/openclip/open_clip/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-M-32.json b/openclip/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-S-16-alt.json b/openclip/open_clip/model_configs/ViT-S-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-S-16.json b/openclip/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-S-32-alt.json b/openclip/open_clip/model_configs/ViT-S-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-S-32.json b/openclip/open_clip/model_configs/ViT-S-32.json new file mode 100644 index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-bigG-14.json b/openclip/open_clip/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-e-14.json b/openclip/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/ViT-g-14.json b/openclip/open_clip/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/openclip/open_clip/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/coca_ViT-B-32.json b/openclip/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/openclip/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/coca_ViT-L-14.json b/openclip/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3d5ca4ca2338540f06852df5ff35ea6277e64555 --- /dev/null +++ b/openclip/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/openclip/open_clip/model_configs/coca_base.json b/openclip/open_clip/model_configs/coca_base.json new file mode 100644 index 0000000000000000000000000000000000000000..cf8c6cecb78a49d7e7140145a0307cbd561077c2 --- /dev/null +++ b/openclip/open_clip/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/coca_roberta-ViT-B-32.json b/openclip/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..fb46354b95a17a46d7fcfd9d504e917ee6c1608c --- /dev/null +++ b/openclip/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git a/openclip/open_clip/model_configs/convnext_base.json b/openclip/open_clip/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_base_w.json b/openclip/open_clip/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_base_w_320.json b/openclip/open_clip/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_large.json b/openclip/open_clip/model_configs/convnext_large.json new file mode 100644 index 0000000000000000000000000000000000000000..c4a1fea73dbead71c218a0e74b9b15f9b252e3ef --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_large.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_large_d.json b/openclip/open_clip/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_large_d_320.json b/openclip/open_clip/model_configs/convnext_large_d_320.json new file mode 100644 index 0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_small.json b/openclip/open_clip/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_tiny.json b/openclip/open_clip/model_configs/convnext_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..ad11470f5ec40ffec771096971ce58d3d5b9249b --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_tiny.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_tiny", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_xlarge.json b/openclip/open_clip/model_configs/convnext_xlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_xxlarge.json b/openclip/open_clip/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/convnext_xxlarge_320.json b/openclip/open_clip/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/openclip/open_clip/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/mt5-base-ViT-B-32.json b/openclip/open_clip/model_configs/mt5-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..58cad89cf0f446bbe15e4e25b1ac43424a828017 --- /dev/null +++ b/openclip/open_clip/model_configs/mt5-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "google/mt5-base", + "hf_tokenizer_name": "google/mt5-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/openclip/open_clip/model_configs/mt5-xl-ViT-H-14.json b/openclip/open_clip/model_configs/mt5-xl-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b432810777ba7269dbb0e89edfe65cdd27e7d255 --- /dev/null +++ b/openclip/open_clip/model_configs/mt5-xl-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "google/mt5-xl", + "hf_tokenizer_name": "google/mt5-xl", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/openclip/open_clip/model_configs/roberta-ViT-B-32.json b/openclip/open_clip/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..ed687d472a73bb2ac96025f355f80437ab14c260 --- /dev/null +++ b/openclip/open_clip/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/openclip/open_clip/model_configs/swin_base_patch4_window7_224.json b/openclip/open_clip/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/openclip/open_clip/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/vit_medium_patch16_gap_256.json b/openclip/open_clip/model_configs/vit_medium_patch16_gap_256.json new file mode 100644 index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65 --- /dev/null +++ b/openclip/open_clip/model_configs/vit_medium_patch16_gap_256.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_medium_patch16_gap_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json b/openclip/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/openclip/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/openclip/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/openclip/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..751bccc2c6fc41bc4ff20182de88d86739d518d9 --- /dev/null +++ b/openclip/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/openclip/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/openclip/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..31f271faa9bbb7a9da53900b483a4c00a16f3c4a --- /dev/null +++ b/openclip/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-large", + "hf_tokenizer_name": "xlm-roberta-large", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/openclip/open_clip/modified_resnet.py b/openclip/open_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c0b033a80e7d08a20a367050c5b1bc5d5292e7 --- /dev/null +++ b/openclip/open_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from open_clip.utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/openclip/open_clip/openai.py b/openclip/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4e13e876d6a7a3463b457e62c517cb063b1356 --- /dev/null +++ b/openclip/open_clip/openai.py @@ -0,0 +1,144 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith('amp') or precision == 'fp32': + model.float() + elif precision == 'bf16': + convert_weights_to_lp(model, dtype=torch.bfloat16) + + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 (typically for CPU) + if precision == 'fp32': + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model diff --git a/openclip/open_clip/pretrained.py b/openclip/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..87e7e527497d643fdf6ac931ac73b6e887a90d0d --- /dev/null +++ b/openclip/open_clip/pretrained.py @@ -0,0 +1,376 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_RN50 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN101 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN50x4 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), +) + +_RN50x16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), +) + +_RN50x64 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), +) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + # laion400m_32k=_pcfg( + # url="", + # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # laion400m_64k=_pcfg( + # url="", + # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + "convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/openclip/open_clip/push_to_hf_hub.py b/openclip/open_clip/push_to_hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..23c0631c81dcb43829b7374fac09406ecefcb436 --- /dev/null +++ b/openclip/open_clip/push_to_hf_hub.py @@ -0,0 +1,243 @@ +import argparse +import json +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple + +import torch + +try: + from huggingface_hub import ( + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + ) + from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True +except ImportError: + _has_hf_hub = False + +from .factory import create_model_from_pretrained, get_model_config, get_tokenizer +from .tokenizer import HFTokenizer + + +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] +): + preprocess_cfg = { + 'mean': model.visual.image_mean, + 'std': model.visual.image_std, + } + hf_config = { + 'model_cfg': model_config, + 'preprocess_cfg': preprocess_cfg, + } + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def save_for_hf( + model, + tokenizer: HFTokenizer, + model_config: dict, + save_directory: str, + weights_filename='open_clip_pytorch_model.bin', + config_filename='open_clip_config.json', +): + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + weights_path = save_directory / weights_filename + torch.save(model.state_dict(), weights_path) + + tokenizer.save_pretrained(save_directory) + + config_path = save_directory / config_filename + save_config_for_hf(model, config_path, model_config=model_config) + + +def push_to_hf_hub( + model, + tokenizer, + model_config: Optional[dict], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + if not isinstance(tokenizer, HFTokenizer): + # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 + tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if README file already exist in repo + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. + save_for_hf( + model, + tokenizer=tokenizer, + model_config=model_config, + save_directory=tmpdir, + ) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def push_pretrained_to_hf_hub( + model_name, + pretrained: str, + repo_id: str, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + model, preprocess_eval = create_model_from_pretrained( + model_name, + pretrained=pretrained, + image_mean=image_mean, + image_std=image_std, + ) + + model_config = get_model_config(model_name) + assert model_config + + tokenizer = get_tokenizer(model_name) + + push_to_hf_hub( + model=model, + tokenizer=tokenizer, + model_config=model_config, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + ) + + +def generate_readme(model_card: dict, model_name: str): + readme_text = "---\n" + readme_text += "tags:\n- zero-shot-image-classification\n- clip\n" + readme_text += "library_tag: open_clip\n" + readme_text += f"license: {model_card.get('license', 'mit')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + + return readme_text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") + parser.add_argument( + "--model", type=str, help="Name of the model to use.", + ) + parser.add_argument( + "--pretrained", type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--repo-id", type=str, + help="Destination HF Hub repo-id ie 'organization/model_id'.", + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + args = parser.parse_args() + + print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + + # FIXME add support to pass model_card json / template from file via cmd line + + push_pretrained_to_hf_hub( + args.model, + args.pretrained, + args.repo_id, + image_mean=args.image_mean, # override image mean/std if trained w/ non defaults + image_std=args.image_std, + ) + + print(f'{args.model} saved.') diff --git a/openclip/open_clip/timm_model.py b/openclip/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dc71a693f9a42ec01fd88d307661bc382b4d05bc --- /dev/null +++ b/openclip/open_clip/timm_model.py @@ -0,0 +1,127 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if pool in ('abs_attn', 'rot_attn'): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, 'projection layer needed if non-attention pooling is used.' + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/openclip/open_clip/tokenizer.py b/openclip/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282 --- /dev/null +++ b/openclip/open_clip/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/openclip/open_clip/transform.py b/openclip/open_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..748884a3c7cb7ece1ca521ca1dbf40bb74855007 --- /dev/null +++ b/openclip/open_clip/transform.py @@ -0,0 +1,133 @@ +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/openclip/open_clip/transformer.py b/openclip/open_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0151017a6e12396034927c99174bce87d31f13 --- /dev/null +++ b/openclip/open_clip/transformer.py @@ -0,0 +1,724 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/openclip/open_clip/utils.py b/openclip/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51e80c5e296b24cae130ab0459baf268e0db7673 --- /dev/null +++ b/openclip/open_clip/utils.py @@ -0,0 +1,60 @@ +from itertools import repeat +import collections.abc + +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) diff --git a/openclip/open_clip/version.py b/openclip/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..48aa744fb053599044caf0253b889b5cfe5b78e7 --- /dev/null +++ b/openclip/open_clip/version.py @@ -0,0 +1 @@ +__version__ = '2.16.0' diff --git a/openclip/open_clip_torch.egg-info/PKG-INFO b/openclip/open_clip_torch.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..1c394983121588c0a6d5ecc394e6c450b965f293 --- /dev/null +++ b/openclip/open_clip_torch.egg-info/PKG-INFO @@ -0,0 +1,800 @@ +Metadata-Version: 2.1 +Name: open-clip-torch +Version: 2.16.0 +Summary: OpenCLIP +Home-page: https://github.com/mlfoundations/open_clip +Author: +Author-email: +Keywords: CLIP pretrained +Classifier: Development Status :: 3 - Alpha +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Topic :: Scientific/Engineering +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Classifier: Topic :: Software Development +Classifier: Topic :: Software Development :: Libraries +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Requires-Python: >=3.7 +Description-Content-Type: text/markdown +Provides-Extra: training +License-File: LICENSE + +# OpenCLIP + +[[Paper]](https://arxiv.org/abs/2212.07143) [[Clip Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb) [[Coca Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_coca.ipynb) +[![pypi](https://img.shields.io/pypi/v/open_clip_torch.svg)](https://pypi.python.org/pypi/open_clip_torch) + +Welcome to an open source implementation of OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) (Contrastive Language-Image Pre-training). + +The goal of this repository is to enable training models with contrastive image-text supervision, and to investigate their properties such as robustness to distribution shift. Our starting point is an implementation of CLIP that matches the accuracy of the original CLIP models when trained on the same dataset. +Specifically, a ResNet-50 model trained with our codebase on OpenAI's [15 million image subset of YFCC](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md) achieves **32.7%** top-1 accuracy on ImageNet. OpenAI's CLIP model reaches **31.3%** when trained on the same subset of YFCC. For ease of experimentation, we also provide code for training on the 3 million images in the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/download) dataset, where a ResNet-50x4 trained with our codebase reaches 22.2% top-1 ImageNet accuracy. + +We further this with a replication study on a dataset of comparable size to OpenAI's, [LAION-400M](https://arxiv.org/abs/2111.02114), and with the larger [LAION-2B](https://laion.ai/blog/laion-5b/) superset. In addition, we study scaling behavior in a paper on [reproducible scaling laws for contrastive language-image learning](https://arxiv.org/abs/2212.07143). + +We have trained the following ViT CLIP models: + * ViT-B/32 on LAION-400M with a accuracy of **62.9%**, comparable to OpenAI's **63.2%**, zero-shot top-1 on ImageNet-1k + * ViT-B/32 on LAION-2B with a accuracy of **66.6%**. + * ViT-B/16 on LAION-400M achieving an accuracy of **67.1%**, lower than OpenAI's **68.3%** (as measured here, 68.6% in paper) + * ViT-B/16+ 240x240 (~50% more FLOPS than B/16 224x224) on LAION-400M achieving an accuracy of **69.2%** + * ViT-B/16 on LAION-2B with a accuracy of **70.2%**. + * ViT-L/14 on LAION-400M with an accuracy of **72.77%**, vs OpenAI's **75.5%** (as measured here, 75.3% in paper) + * ViT-L/14 on LAION-2B with an accuracy of **75.3%**, vs OpenAI's **75.5%** (as measured here, 75.3% in paper) + * CoCa ViT-L/14 on LAION-2B with an accuracy of **75.5%** (currently only 13B samples seen) vs. CLIP ViT-L/14 73.1% (on the same dataset and samples seen) + * ViT-H/14 on LAION-2B with an accuracy of **78.0**. The second best in1k zero-shot for released, open-source weights thus far. + * ViT-g/14 on LAION-2B with an accuracy of **76.6**. This was trained on reduced 12B samples seen schedule, same samples seen as 400M models. + * ViT-g/14 on LAION-2B with an accuracy of **78.5**. Full 34B samples seen schedule. + * ViT-G/14 on LAION-2B with an accuracy of **80.1**. The best in1k zero-shot for released, open-source weights thus far. + +And the following ConvNeXt CLIP models: + * ConvNext-Base @ 224x224 on LAION-400M with an ImageNet-1k zero-shot top-1 of **66.3%** + * ConvNext-Base (W) @ 256x256 on LAION-2B with an ImageNet-1k zero-shot top-1 of **70.8%** + * ConvNext-Base (W) @ 256x256 /w augreg (extra augmentation + regularization) on LAION-2B with a top-1 of **71.5%** + * ConvNext-Base (W) @ 256x256 on LAION-A (900M sample aesthetic subset of 2B) with a top-1 of **71.0%** + * ConvNext-Base (W) @ 320x320 on LAION-A with a top-1 of **71.7%** (eval at 384x384 is **71.0**) + * ConvNext-Base (W) @ 320x320 /w augreg on LAION-A with a top-1 of **71.3%** (eval at 384x384 is **72.2%**) + * ConvNext-Large (D) @ 256x256 /w augreg on LAION-2B with a top-1 of **75.9%** + * ConvNext-Large (D) @ 320x320 fine-tune of 256x256 weights above for ~2.5B more samples on LAION-2B, top-1 of **76.6%** + * ConvNext-Large (D) @ 320x320 soup of 3 fine-tunes of 256x256 weights above on LAION-2B, top-1 of **76.9%** + * ConvNext-XXLarge @ 256x256 original run **79.1%** + * ConvNext-XXLarge @ 256x256 rewind of last 10% **79.3%** + * ConvNext-XXLarge @ 256x256 soup of original + rewind **79.4%** + +Model cards w/ additional model specific details can be found on the Hugging Face Hub under the OpenCLIP library tag: https://huggingface.co/models?library=open_clip + +As we describe in more detail [below](#why-are-low-accuracy-clip-models-interesting), CLIP models in a medium accuracy regime already allow us to draw conclusions about the robustness of larger CLIP models since the models follow [reliable scaling laws](https://arxiv.org/abs/2107.04649). + +This codebase is work in progress, and we invite all to contribute in making it more accessible and useful. In the future, we plan to add support for TPU training and release larger models. We hope this codebase facilitates and promotes further research in contrastive image-text learning. Please submit an issue or send an email if you have any other requests or suggestions. + +Note that portions of `src/open_clip/` modelling and tokenizer code are adaptations of OpenAI's official [repository](https://github.com/openai/CLIP). + +## Approach + +| ![CLIP](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/CLIP.png) | +|:--:| +| Image Credit: https://github.com/openai/CLIP | + +## Usage + +``` +pip install open_clip_torch +``` + +```python +import torch +from PIL import Image +import open_clip + +model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32') +tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu') + +image = preprocess(Image.open("CLIP.png")).unsqueeze(0) +text = tokenizer(["a diagram", "a dog", "a cat"]) + +with torch.no_grad(), torch.cuda.amp.autocast(): + image_features = model.encode_image(image) + text_features = model.encode_text(text) + image_features /= image_features.norm(dim=-1, keepdim=True) + text_features /= text_features.norm(dim=-1, keepdim=True) + + text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) + +print("Label probs:", text_probs) # prints: [[1., 0., 0.]] +``` +See also this [[Clip Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb) + +To compute billions of embeddings efficiently, you can use [clip-retrieval](https://github.com/rom1504/clip-retrieval) which has openclip support. + +## Fine-tuning on classification tasks + +This repository is focused on training CLIP models. To fine-tune a *trained* zero-shot model on a downstream classification task such as ImageNet, please see [our other repository: WiSE-FT](https://github.com/mlfoundations/wise-ft). The [WiSE-FT repository](https://github.com/mlfoundations/wise-ft) contains code for our paper on [Robust Fine-tuning of Zero-shot Models](https://arxiv.org/abs/2109.01903), in which we introduce a technique for fine-tuning zero-shot models while preserving robustness under distribution shift. + +## Data + +To download datasets as webdataset, we recommend [img2dataset](https://github.com/rom1504/img2dataset) + +### Conceptual Captions + +See [cc3m img2dataset example](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) + +### YFCC and other datasets + +In addition to specifying the training data via CSV files as mentioned above, our codebase also supports [webdataset](https://github.com/webdataset/webdataset), which is recommended for larger scale datasets. The expected format is a series of `.tar` files. Each of these `.tar` files should contain two files for each training example, one for the image and one for the corresponding text. Both files should have the same name but different extensions. For instance, `shard_001.tar` could contain files such as `abc.jpg` and `abc.txt`. You can learn more about `webdataset` at [https://github.com/webdataset/webdataset](https://github.com/webdataset/webdataset). We use `.tar` files with 1,000 data points each, which we create using [tarp](https://github.com/webdataset/tarp). + +You can download the YFCC dataset from [Multimedia Commons](http://mmcommons.org/). +Similar to OpenAI, we used a subset of YFCC to reach the aforementioned accuracy numbers. +The indices of images in this subset are in [OpenAI's CLIP repository](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md). + + +## Training CLIP + +### Install + +We advise you first create a virtual environment with: + +``` +python3 -m venv .env +source .env/bin/activate +pip install -U pip +``` + +You can then install openclip for training with `pip install 'open_clip_torch[training]'`. + +#### Development + +If you want to make changes to contribute code, you can close openclip then run `make install` in openclip folder (after creating a virtualenv) + +Install pip PyTorch as per https://pytorch.org/get-started/locally/ + +You may run `make install-training` to install training deps + +#### Testing + +Test can be run with `make install-test` then `make test` + +`python -m pytest -x -s -v tests -k "training"` to run a specific test + +Running regression tests against a specific git revision or tag: +1. Generate testing data + ```sh + python tests/util_test.py --model RN50 RN101 --save_model_list models.txt --git_revision 9d31b2ec4df6d8228f370ff20c8267ec6ba39383 + ``` + **_WARNING_: This will invoke git and modify your working tree, but will reset it to the current state after data has been generated! \ + Don't modify your working tree while test data is being generated this way.** + +2. Run regression tests + ```sh + OPEN_CLIP_TEST_REG_MODELS=models.txt python -m pytest -x -s -v -m regression_test + ``` + +### Sample single-process running code: + +```bash +python -m training.main \ + --save-frequency 1 \ + --zeroshot-frequency 1 \ + --report-to tensorboard \ + --train-data="/path/to/train_data.csv" \ + --val-data="/path/to/validation_data.csv" \ + --csv-img-key filepath \ + --csv-caption-key title \ + --imagenet-val=/path/to/imagenet/root/val/ \ + --warmup 10000 \ + --batch-size=128 \ + --lr=1e-3 \ + --wd=0.1 \ + --epochs=30 \ + --workers=8 \ + --model RN50 +``` + +Note: `imagenet-val` is the path to the *validation* set of ImageNet for zero-shot evaluation, not the training set! +You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders. If it doest not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh). + +### Multi-GPU and Beyond + +This code has been battle tested up to 1024 A100s and offers a variety of solutions +for distributed training. We include native support for SLURM clusters. + +As the number of devices used to train increases, so does the space complexity of +the the logit matrix. Using a naïve all-gather scheme, space complexity will be +`O(n^2)`. Instead, complexity may become effectively linear if the flags +`--gather-with-grad` and `--local-loss` are used. This alteration results in one-to-one +numerical results as the naïve method. + +#### Epochs + +For larger datasets (eg Laion2B), we recommend setting --train-num-samples to a lower value than the full epoch, for example `--train-num-samples 135646078` to 1/16 of an epoch in conjunction with --dataset-resampled to do sampling with replacement. This allows having frequent checkpoints to evaluate more often. + +#### Patch Dropout + +Recent research has shown that one can dropout half to three-quarters of the visual tokens, leading to up to 2-3x training speeds without loss of accuracy. + +You can set this on your visual transformer config with the key `patch_dropout`. + +In the paper, they also finetuned without the patch dropout at the end. You can do this with the command-line argument `--force-patch-dropout 0.` + +#### Multiple data sources + +OpenCLIP supports using multiple data sources, by separating different data paths with `::`. +For instance, to train on CC12M and on LAION, one might use `--train-data '/data/cc12m/cc12m-train-{0000..2175}.tar'::/data/LAION-400M/{00000..41455}.tar"`. +Using `--dataset-resampled` is recommended for these cases. + +By default, on expectation the amount of times the model will see a sample from each source is proportional to the size of the source. +For instance, when training on one data source with size 400M and one with size 10M, samples from the first source are 40x more likely to be seen in expectation. + +We also support different weighting of the data sources, by using the `--train-data-upsampling-factors` flag. +For instance, using `--train-data-upsampling-factors=1::1` in the above scenario is equivalent to not using the flag, and `--train-data-upsampling-factors=1::2` is equivalent to upsampling the second data source twice. +If you want to sample from data sources with the same frequency, the upsampling factors should be inversely proportional to the sizes of the data sources. +For instance, if dataset `A` has 1000 samples and dataset `B` has 100 samples, you can use `--train-data-upsampling-factors=0.001::0.01` (or analogously, `--train-data-upsampling-factors=1::10`). + +#### Single-Node + +We make use of `torchrun` to launch distributed jobs. The following launches a +a job on a node of 4 GPUs: + +```bash +cd open_clip/src +torchrun --nproc_per_node 4 -m training.main \ + --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \ + --train-num-samples 10968539 \ + --dataset-type webdataset \ + --batch-size 320 \ + --precision amp \ + --workers 4 \ + --imagenet-val /data/imagenet/validation/ +``` + +#### Multi-Node + +The same script above works, so long as users include information about the number +of nodes and host node. + +```bash +cd open_clip/src +torchrun --nproc_per_node=4 \ + --rdzv_endpoint=$HOSTE_NODE_ADDR \ + -m training.main \ + --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \ + --train-num-samples 10968539 \ + --dataset-type webdataset \ + --batch-size 320 \ + --precision amp \ + --workers 4 \ + --imagenet-val /data/imagenet/validation/ +``` + +#### SLURM + +This is likely the easiest solution to utilize. The following script was used to +train our largest models: + +```bash +#!/bin/bash -x +#SBATCH --nodes=32 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=6 +#SBATCH --wait-all-nodes=1 +#SBATCH --job-name=open_clip +#SBATCH --account=ACCOUNT_NAME +#SBATCH --partition PARTITION_NAME + +eval "$(/path/to/conda/bin/conda shell.bash hook)" # init conda +conda activate open_clip +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export MASTER_PORT=12802 + +master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_ADDR=$master_addr + +cd /shared/open_clip +export PYTHONPATH="$PYTHONPATH:$PWD/src" +srun --cpu_bind=v --accel-bind=gn python -u src/training/main.py \ + --save-frequency 1 \ + --report-to tensorboard \ + --train-data="/data/LAION-400M/{00000..41455}.tar" \ + --warmup 2000 \ + --batch-size=256 \ + --epochs=32 \ + --workers=8 \ + --model ViT-B-32 \ + --name "ViT-B-32-Vanilla" \ + --seed 0 \ + --local-loss \ + --gather-with-grad +``` + +### Resuming from a checkpoint: + +```bash +python -m training.main \ + --train-data="/path/to/train_data.csv" \ + --val-data="/path/to/validation_data.csv" \ + --resume /path/to/checkpoints/epoch_K.pt +``` + +### Training CoCa: +Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled through specifying a CoCa config using the ```--model``` parameter of the training script. Currently available configs are "coca_base", "coca_ViT-B-32", and "coca_roberta-ViT-B-32" (which uses RoBERTa as the text encoder). CoCa configs are different from CLIP configs because they have an additional "multimodal_cfg" component which specifies parameters for the multimodal text decoder. Here's an example from the coca_ViT-B-32 config: +```json +"multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "latent_dim": 512, + "attn_pooler_heads": 8 +} +``` +Credit to [lucidrains](https://github.com/lucidrains) for [initial code](https://github.com/lucidrains/CoCa-pytorch), [gpucce](https://github.com/gpucce) for adapting the code to open_clip, and [iejMac](https://github.com/iejMac) for training the models. + +### Generating text with CoCa + +```python +import open_clip +import torch +from PIL import Image + +model, _, transform = open_clip.create_model_and_transforms( + model_name="coca_ViT-L-14", + pretrained="mscoco_finetuned_laion2B-s13B-b90k" +) + +im = Image.open("cat.jpg").convert("RGB") +im = transform(im).unsqueeze(0) + +with torch.no_grad(), torch.cuda.amp.autocast(): + generated = model.generate(im) + +print(open_clip.decode(generated[0]).split("")[0].replace("", "")) +``` + +See also this [[Coca Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_coca.ipynb) + +### Fine Tuning CoCa + +To fine-tune coca on mscoco, first create the dataset, one way is using a csvdataset and perhaps the simplest way to do it is using [CLIP_benchmark](https://github.com/LAION-AI/CLIP_benchmark) which in turn uses [pycocotools](https://github.com/cocodataset/cocoapi) (that can be used also by itself). + +```python +from clip_benchmark.datasets.builder import build_dataset +import pandas as pd +import os + +root_path = "path/to/data/dir" # set this to smth meaningful +ds = build_dataset("mscoco_captions", root=root_path, split="train") # this downloads the dataset if it is not there already +coco = ds.coco +imgs = coco.loadImgs(coco.getImgIds()) +future_df = {"filepath":[], "title":[]} +for img in imgs: + caps = coco.imgToAnns[img["id"]] + for cap in caps: + future_df["filepath"].append(img["file_name"]) + future_df["title"].append(cap["caption"]) +pd.DataFrame.from_dict(future_df).to_csv( + os.path.join(root_path, "train2014.csv"), index=False, sep="\t" +) +``` +This should create a csv dataset that one can use to fine-tune coca with open_clip +```bash +python -m training.main \ + --dataset-type "csv" \ + --train-data "path/to/data/dir/train2014.csv" \ + --warmup 1000 \ + --batch-size 128 \ + --lr 1e-5 \ + --wd 0.1 \ + --epochs 1 \ + --workers 3 \ + --model "coca_ViT-L-14" \ + --report-to "wandb" \ + --coca-contrastive-loss-weight 0 \ + --coca-caption-loss-weight 1 \ + --log-every-n-steps 100 +``` + +This is a general setting, open_clip has very parameters that can be set, ```python -m training.main --help``` should show them. The only relevant change compared to pre-training are the two arguments + +```bash +--coca-contrastive-loss-weight 0 +--coca-caption-loss-weight 1 +``` +which make the model only train the generative side. + +### Training with pre-trained language models as text encoder: + +If you wish to use different language models as the text encoder for CLIP you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing in it's tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters respectively. Currently we only support RoBERTa ("test-roberta" config), however adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with the RoBERTa LM that has it's last 10 layers unfrozen: +```bash +python -m training.main \ + --train-data="pipe:aws s3 cp s3://s-mas/cc3m/{00000..00329}.tar -" \ + --train-num-samples 3000000 \ + --val-data="pipe:aws s3 cp s3://s-mas/cc3m/{00330..00331}.tar -" \ + --val-num-samples 10000 \ + --dataset-type webdataset \ + --batch-size 256 \ + --warmup 2000 \ + --epochs 10 \ + --lr 5e-4 \ + --precision amp \ + --workers 6 \ + --model "roberta-ViT-B-32" \ + --lock-text \ + --lock-text-unlocked-layers 10 \ + --name "10_unfrozen" \ + --report-to "tensorboard" \ +``` + +### Loss Curves + +When run on a machine with 8 GPUs the command should produce the following training curve for Conceptual Captions: + +![CLIP zero shot training curve](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/clip_zeroshot.png) + +More detailed curves for Conceptual Captions are given at [/docs/clip_conceptual_captions.md](/docs/clip_conceptual_captions.md). + +When training a RN50 on YFCC the same hyperparameters as above are used, with the exception of `lr=5e-4` and `epochs=32`. + +Note that to use another model, like `ViT-B/32` or `RN50x4` or `RN50x16` or `ViT-B/16`, specify with `--model RN50x4`. + +### Launch tensorboard: +```bash +tensorboard --logdir=logs/tensorboard/ --port=7777 +``` + +## Evaluation / Zero-Shot + +We recommend https://github.com/LAION-AI/CLIP_benchmark#how-to-use for systematic evaluation on 40 datasets. + +### Evaluating local checkpoint: + +```bash +python -m training.main \ + --val-data="/path/to/validation_data.csv" \ + --model RN101 \ + --pretrained /path/to/checkpoints/epoch_K.pt +``` + +### Evaluating hosted pretrained checkpoint on ImageNet zero-shot prediction: + +```bash +python -m training.main \ + --imagenet-val /path/to/imagenet/validation \ + --model ViT-B-32-quickgelu \ + --pretrained laion400m_e32 +``` + +## Pretrained model details + +### LAION-400M - https://laion.ai/laion-400-open-dataset + +We are working on reproducing OpenAI's ViT results with the comparably sized (and open) LAION-400M dataset. Trained +weights may be found in release [v0.2](https://github.com/mlfoundations/open_clip/releases/tag/v0.2-weights). + +The LAION400M weights have been trained on the JUWELS supercomputer (see acknowledgements section below). + +#### ViT-B/32 224x224 + +We replicate OpenAI's results on ViT-B/32, reaching a top-1 ImageNet-1k zero-shot accuracy of 62.96%. + + + +__Zero-shot comparison (courtesy of Andreas Fürst)__ + + +ViT-B/32 was trained with 128 A100 (40 GB) GPUs for ~36 hours, 4600 GPU-hours. The per-GPU batch size was 256 for a global batch size of 32768. 256 is much lower than it could have been (~320-384) due to being sized initially before moving to 'local' contrastive loss. + +#### ViT-B/16 224x224 + +The B/16 LAION400M training reached a top-1 ImageNet-1k zero-shot validation score of 67.07. + + + +This was the first major train session using the updated webdataset 0.2.x code. A bug was found that prevented shards from being shuffled properly between nodes/workers each epoch. This was fixed part way through training (epoch 26) but likely had an impact. + +ViT-B/16 was trained with 176 A100 (40 GB) GPUS for ~61 hours, 10700 GPU-hours. Batch size per GPU was 192 for a global batch size of 33792. + +#### ViT-B/16+ 240x240 + +The B/16+ 240x240 LAION400M training reached a top-1 ImageNet-1k zero-shot validation score of 69.21. + +This model is the same depth as the B/16, but increases the + * vision width from 768 -> 896 + * text width from 512 -> 640 + * the resolution 224x224 -> 240x240 (196 -> 225 tokens) + + + +Unlike the B/16 run above, this model was a clean run with no dataset shuffling issues. + +ViT-B/16+ was trained with 224 A100 (40 GB) GPUS for ~61 hours, 13620 GPU-hours. Batch size per GPU was 160 for a global batch size of 35840. + +#### ViT-L/14 224x224 + +The L/14 LAION-400M training reached a top-1 ImageNet-1k zero-shot validation score of 72.77. + + + +ViT-L/14 was trained with 400 A100 (40 GB) GPUS for ~127 hours, 50800 GPU-hours. Batch size per GPU was 96 for a global batch size of 38400. Grad checkpointing was enabled. + +### LAION-2B (en) - https://laion.ai/laion-5b-a-new-era-of-open-large-scale-multi-modal-datasets/ + +A ~2B sample subset of LAION-5B with english captions (https://huggingface.co/datasets/laion/laion2B-en) + +#### ViT-B/32 224x224 +A ViT-B/32 trained on LAION-2B, reaching a top-1 ImageNet-1k zero-shot accuracy of 65.62%. + + + +ViT-B/32 was trained with 112 A100 (40 GB) GPUs. The per-GPU batch size was 416 for a global batch size of 46592. Compute generously provided by [stability.ai](https://stability.ai/). + +A second iteration of B/32 was trained on stability.ai cluster with a larger global batch size and learning rate, hitting 66.6% top-1. See https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K + +#### ViT-L/14 224x224 + +A ViT-L/14 with a 75.3% top-1 ImageNet-1k zero-shot was trained on JUWELS Booster. See model details here https://huggingface.co/laion/CLIP-ViT-L-14-laion2B-s32B-b82K + +These weights use a different dataset mean and std than others. Instead of using the OpenAI mean & std, inception style normalization `[-1, 1]` is used via a mean and std of `[0.5, 0.5, 0.5]`. This is handled automatically if using `open_clip.create_model_and_transforms` from pretrained weights. + +#### ViT-H/14 224x224 + +A ViT-H/14 with a 78.0% top-1 ImageNet-1k zero-shot was trained on JUWELS Booster. See model details here https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K + +#### ViT-g/14 224x224 + +A ViT-g/14 with a 76.6% top-1 ImageNet-1k zero-shot was trained on JUWELS Booster. See model details here https://huggingface.co/laion/CLIP-ViT-g-14-laion2B-s12B-b42K + +This model was trained with a shorted schedule than other LAION-2B models with 12B samples seen instead of 32+B. It matches LAION-400M training in samples seen. Many zero-shot results are lower as a result, but despite this it performs very well in some OOD zero-shot and retrieval tasks. + + +#### ViT-B/32 roberta base + +A ViT-B/32 with roberta base encoder with a 61.7% top-1 ImageNet-1k zero-shot was trained on stability. See model details here https://huggingface.co/laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k +This is the first openclip model using a HF text tower. It has better performance on a range of tasks compared to the standard text encoder, see [metrics](https://huggingface.co/laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/blob/main/unknown.png) + +#### ViT-B/32 xlm roberta base + +A ViT-B/32 with xlm roberta base encoder with a 62.33% top-1 ImageNet-1k zero-shot was trained on stability. See model details here https://huggingface.co/laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k +This is the first openclip model trained on the full laion5B dataset; hence the first multilingual clip trained with openclip. It has better performance on a range of tasks compared to the standard text encoder, see [metrics](https://huggingface.co/laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/blob/main/metrics.png) +A preliminary multilingual evaluation was run: 43% on imagenet1k italian (vs 21% for english B/32), 37% for imagenet1k japanese (vs 1% for english B/32 and 50% for B/16 clip japanese). It shows the multilingual property is indeed there as expected. Larger models will get even better performance. + +#### ViT-H/14 xlm roberta large + +A ViT-H/14 with xlm roberta large encoder with a 77.0% (vs 78% for the english equivalent) top-1 ImageNet-1k zero-shot was trained on stability. See model details here https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k + +This model was trained following the [LiT](https://arxiv.org/abs/2111.07991) methodology: the image tower was frozen (initialized from english openclip ViT-H/14), the text tower was initialized from [xlm roberta large](https://huggingface.co/xlm-roberta-large) and unfrozen. This reduced training cost by a 3x factor. + +See full english [metrics](https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/resolve/main/results_xlm_roberta_large.png) + +On zero shot classification on imagenet with translated prompts this model reaches: + +* 56% in italian (vs 21% for https://github.com/clip-italian/clip-italian) +* 53% in japanese (vs 54.6% for https://github.com/rinnakk/japanese-clip) +* 55.7% in chinese (to be compared with https://github.com/OFA-Sys/Chinese-CLIP) + + +#### YFCC-15M + +Below are checkpoints of models trained on YFCC-15M, along with their zero-shot top-1 accuracies on ImageNet and ImageNetV2. These models were trained using 8 GPUs and the same hyperparameters described in the "Sample running code" section, with the exception of `lr=5e-4` and `epochs=32`. + +* [ResNet-50](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt) (32.7% / 27.9%) +* [ResNet-101](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt) (34.8% / 30.0%) + +#### CC12M - https://github.com/google-research-datasets/conceptual-12m + +* [ResNet-50](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt) (36.45%) + +### Pretrained Model Interface + +We offer a simple model interface to instantiate both pre-trained and untrained models. + +NOTE: Many existing checkpoints use the QuickGELU activation from the original OpenAI models. This activation is actually less efficient than native torch.nn.GELU in recent versions of PyTorch. The model defaults are now nn.GELU, so one should use model definitions with `-quickgelu` postfix for the OpenCLIP pretrained weights. All OpenAI pretrained weights will always default to QuickGELU. One can also use the non `-quickgelu` model definitions with pretrained weights using QuickGELU but there will be an accuracy drop, for fine-tune that will likely vanish for longer runs. + +Future trained models will use nn.GELU. + +```python +>>> import open_clip +>>> open_clip.list_pretrained() +[('RN50', 'openai'), + ('RN50', 'yfcc15m'), + ('RN50', 'cc12m'), + ('RN50-quickgelu', 'openai'), + ('RN50-quickgelu', 'yfcc15m'), + ('RN50-quickgelu', 'cc12m'), + ('RN101', 'openai'), + ('RN101', 'yfcc15m'), + ('RN101-quickgelu', 'openai'), + ('RN101-quickgelu', 'yfcc15m'), + ('RN50x4', 'openai'), + ('RN50x16', 'openai'), + ('RN50x64', 'openai'), + ('ViT-B-32', 'openai'), + ('ViT-B-32', 'laion400m_e31'), + ('ViT-B-32', 'laion400m_e32'), + ('ViT-B-32', 'laion2b_e16'), + ('ViT-B-32', 'laion2b_s34b_b79k'), + ('ViT-B-32-quickgelu', 'openai'), + ('ViT-B-32-quickgelu', 'laion400m_e31'), + ('ViT-B-32-quickgelu', 'laion400m_e32'), + ('ViT-B-16', 'openai'), + ('ViT-B-16', 'laion400m_e31'), + ('ViT-B-16', 'laion400m_e32'), + ('ViT-B-16-plus-240', 'laion400m_e31'), + ('ViT-B-16-plus-240', 'laion400m_e32'), + ('ViT-L-14', 'openai'), + ('ViT-L-14', 'laion400m_e31'), + ('ViT-L-14', 'laion400m_e32'), + ('ViT-L-14', 'laion2b_s32b_b82k'), + ('ViT-L-14-336', 'openai'), + ('ViT-H-14', 'laion2b_s32b_b79k'), + ('ViT-g-14', 'laion2b_s12b_b42k'), + ('ViT-bigG-14', 'laion2b_s39b_b160k'), + ('roberta-ViT-B-32', 'laion2b_s12b_b32k'), + ('xlm-roberta-base-ViT-B-32', 'laion5b_s13b_b90k'), + ('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'), + ('coca_ViT-B-32', 'laion2B-s13B-b90k'), + ('coca_ViT-B-32', 'mscoco_finetuned_laion2B-s13B-b90k'), # finetuned models lose contrastive capabilities + ('coca_ViT-L-14', 'laion2B-s13B-b90k'), + ('coca_ViT-L-14', 'mscoco_finetuned_laion2B-s13B-b90k'),] # finetuned models lose contrastive capabilities + +>>> model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k') +``` +### Model distillation + +You can distill from a pre-trained by using `--distill-model` and `--distill-pretrained` to specify the model you'd like to distill from. +For instance, to distill from OpenAI ViT-L/14 use `--distill-model ViT-L-14 --distill-pretrained openai`. + +### Gradient accumulation + +To simulate larger batches use `--accum-freq k`. If per gpu batch size, `--batch-size`, is `m`, then the effective batch size will be `k * m * num_gpus`. + +When increasing `--accum-freq` from its default of 1, samples/s will remain approximately constant (batch size will double, as will time-per-batch). It is recommended to use other features to reduce batch size such as `--grad-checkpointing --local-loss --gather-with-grad` before increasing `--accum-freq`. `--accum-freq` can be used in addition to these features. + +Instead of 1 forward pass per example, there are now 2 forward passes per-example. However, the first is done with `torch.no_grad`. + +There is some additional GPU memory required --- the features and data from all `m` batches are stored in memory. + +There are also `m` loss computations instead of the usual 1. + +For more information see Cui et al. (https://arxiv.org/abs/2112.09331) or Pham et al. (https://arxiv.org/abs/2111.10050). + +### Support for remote loading/training + +It is always possible to resume directly from a remote file, e.g., a file in an s3 bucket. Just set `--resume s3:// `. +This will work with any filesystem supported by `fsspec`. + +It is also possible to train `open_clip` models while continuously backing up to s3. This can help to avoid slow local file systems. + +Say that your node has a local ssd `/scratch`, an s3 bucket `s3://`. + +In that case, set `--logs /scratch` and `--remote-sync s3://`. Then, a background process will sync `/scratch/` to `s3:///`. After syncing, the background process will sleep for `--remote-sync-frequency` seconds, which defaults to 5 minutes. + +There is also experimental support for syncing to other remote file systems, not just s3. To do so, specify `--remote-sync-protocol fsspec`. However, this is currently very slow and not recommended. + +Also, to optionally avoid saving too many checkpoints locally when using these features, you can use `--delete-previous-checkpoint` which deletes the previous checkpoint after saving a new one. + +Note: if you are using this feature with `--resume latest`, there are a few warnings. First, use with `--save-most-recent` is not supported. Second, only `s3` is supported. Finally, since the sync happens in the background, it is possible that the most recent checkpoint may not be finished syncing to the remote. + +### Pushing Models to Hugging Face Hub + +The module `open_clip.push_to_hf_hub` includes helpers for pushing models /w weights and config to the HF Hub. + +The tool can be run from command line, ex: +`pytorch -m open_clip.push_to_hf_hub --model convnext_large_d_320 --pretrained /train/checkpoints/epoch_12.pt --repo-id laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft` + +## Scaling trends + +The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples. + + + +## Why are low-accuracy CLIP models interesting? + +**TL;DR:** CLIP models have high effective robustness, even at small scales. + +CLIP models are particularly intriguing because they are more robust to natural distribution shifts (see Section 3.3 in the [CLIP paper](https://arxiv.org/abs/2103.00020)). +This phenomena is illustrated by the figure below, with ImageNet accuracy on the x-axis +and [ImageNetV2](https://arxiv.org/abs/1902.10811) (a reproduction of the ImageNet validation set with distribution shift) accuracy on the y-axis. +Standard training denotes training on the ImageNet train set and the CLIP zero-shot models +are shown as stars. + +![CLIP scatter plot](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/effective_robustness.png) + +As observed by [Taori et al., 2020](https://arxiv.org/abs/2007.00644) and [Miller et al., 2021](https://arxiv.org/abs/2107.04649), the in-distribution +and out-of-distribution accuracies of models trained on ImageNet follow a predictable linear trend (the red line in the above plot). *Effective robustness* +quantifies robustness as accuracy beyond this baseline, i.e., how far a model lies above the red line. Ideally a model would not suffer from distribution shift and fall on the y = x line ([trained human labelers are within a percentage point of the y = x line](http://proceedings.mlr.press/v119/shankar20c.html)). + +Even though the CLIP models trained with +this codebase achieve much lower accuracy than those trained by OpenAI, our models still lie on the same +trend of improved effective robustness (the purple line). Therefore, we can study what makes +CLIP robust without requiring industrial-scale compute. + +For more information on effective robustness, please see: + +- [Recht et al., 2019](https://arxiv.org/abs/1902.10811). +- [Taori et al., 2020](https://arxiv.org/abs/2007.00644). +- [Miller et al., 2021](https://arxiv.org/abs/2107.04649). + +To know more about the factors that contribute to CLIP's robustness refer to [Fang et al., 2022](https://arxiv.org/abs/2205.01397). + +## Acknowledgments + +We gratefully acknowledge the Gauss Centre for Supercomputing e.V. (www.gauss-centre.eu) for funding this part of work by providing computing time through the John von Neumann Institute for Computing (NIC) on the GCS Supercomputer JUWELS Booster at Jülich Supercomputing Centre (JSC). + +## The Team + +Current development of this repository is led by [Ross Wightman](https://rwightman.com/), [Cade Gordon](http://cadegordon.io/), and [Vaishaal Shankar](http://vaishaal.com/). + +The original version of this repository is from a group of researchers at UW, Google, Stanford, Amazon, Columbia, and Berkeley. + +[Gabriel Ilharco*](http://gabrielilharco.com/), [Mitchell Wortsman*](https://mitchellnw.github.io/), [Nicholas Carlini](https://nicholas.carlini.com/), [Rohan Taori](https://www.rohantaori.com/), [Achal Dave](http://www.achaldave.com/), [Vaishaal Shankar](http://vaishaal.com/), [John Miller](https://people.eecs.berkeley.edu/~miller_john/), [Hongseok Namkoong](https://hsnamkoong.github.io/), [Hannaneh Hajishirzi](https://homes.cs.washington.edu/~hannaneh/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/) + +Special thanks to [Jong Wook Kim](https://jongwook.kim/) and [Alec Radford](https://github.com/Newmu) for help with reproducing CLIP! + +## Citing + +If you found this repository useful, please consider citing: +```bibtex +@software{ilharco_gabriel_2021_5143773, + author = {Ilharco, Gabriel and + Wortsman, Mitchell and + Wightman, Ross and + Gordon, Cade and + Carlini, Nicholas and + Taori, Rohan and + Dave, Achal and + Shankar, Vaishaal and + Namkoong, Hongseok and + Miller, John and + Hajishirzi, Hannaneh and + Farhadi, Ali and + Schmidt, Ludwig}, + title = {OpenCLIP}, + month = jul, + year = 2021, + note = {If you use this software, please cite it as below.}, + publisher = {Zenodo}, + version = {0.1}, + doi = {10.5281/zenodo.5143773}, + url = {https://doi.org/10.5281/zenodo.5143773} +} +``` + +```bibtex +@inproceedings{Radford2021LearningTV, + title={Learning Transferable Visual Models From Natural Language Supervision}, + author={Alec Radford and Jong Wook Kim and Chris Hallacy and A. Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever}, + booktitle={ICML}, + year={2021} +} +``` + +```bibtex +@inproceedings{schuhmann2022laionb, + title={{LAION}-5B: An open large-scale dataset for training next generation image-text models}, + author={Christoph Schuhmann and + Romain Beaumont and + Richard Vencu and + Cade W Gordon and + Ross Wightman and + Mehdi Cherti and + Theo Coombes and + Aarush Katta and + Clayton Mullis and + Mitchell Wortsman and + Patrick Schramowski and + Srivatsa R Kundurthy and + Katherine Crowson and + Ludwig Schmidt and + Robert Kaczmarczyk and + Jenia Jitsev}, + booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track}, + year={2022}, + url={https://openreview.net/forum?id=M3Y74vmsMcY} +} +``` + +[![DOI](https://zenodo.org/badge/390536799.svg)](https://zenodo.org/badge/latestdoi/390536799) diff --git a/openclip/open_clip_torch.egg-info/SOURCES.txt b/openclip/open_clip_torch.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..3ee2c6e4a11015a840ddcf0515730b949410c105 --- /dev/null +++ b/openclip/open_clip_torch.egg-info/SOURCES.txt @@ -0,0 +1,103 @@ +LICENSE +MANIFEST.in +README.md +setup.py +src/open_clip/__init__.py +src/open_clip/bpe_simple_vocab_16e6.txt.gz +src/open_clip/coca_model.py +src/open_clip/constants.py +src/open_clip/factory.py +src/open_clip/generation_utils.py +src/open_clip/hf_configs.py +src/open_clip/hf_model.py +src/open_clip/loss.py +src/open_clip/model.py +src/open_clip/modified_resnet.py +src/open_clip/openai.py +src/open_clip/pretrained.py +src/open_clip/push_to_hf_hub.py +src/open_clip/timm_model.py +src/open_clip/tokenizer.py +src/open_clip/transform.py +src/open_clip/transformer.py +src/open_clip/utils.py +src/open_clip/version.py +src/open_clip/model_configs/RN101-quickgelu.json +src/open_clip/model_configs/RN101.json +src/open_clip/model_configs/RN50-quickgelu.json +src/open_clip/model_configs/RN50.json +src/open_clip/model_configs/RN50x16.json +src/open_clip/model_configs/RN50x4.json +src/open_clip/model_configs/RN50x64.json +src/open_clip/model_configs/ViT-B-16-plus-240.json +src/open_clip/model_configs/ViT-B-16-plus.json +src/open_clip/model_configs/ViT-B-16.json +src/open_clip/model_configs/ViT-B-32-plus-256.json +src/open_clip/model_configs/ViT-B-32-quickgelu.json +src/open_clip/model_configs/ViT-B-32.json +src/open_clip/model_configs/ViT-H-14.json +src/open_clip/model_configs/ViT-H-16.json +src/open_clip/model_configs/ViT-L-14-280.json +src/open_clip/model_configs/ViT-L-14-336.json +src/open_clip/model_configs/ViT-L-14.json +src/open_clip/model_configs/ViT-L-16-320.json +src/open_clip/model_configs/ViT-L-16.json +src/open_clip/model_configs/ViT-M-16-alt.json +src/open_clip/model_configs/ViT-M-16.json +src/open_clip/model_configs/ViT-M-32-alt.json +src/open_clip/model_configs/ViT-M-32.json +src/open_clip/model_configs/ViT-S-16-alt.json +src/open_clip/model_configs/ViT-S-16.json +src/open_clip/model_configs/ViT-S-32-alt.json +src/open_clip/model_configs/ViT-S-32.json +src/open_clip/model_configs/ViT-bigG-14.json +src/open_clip/model_configs/ViT-e-14.json +src/open_clip/model_configs/ViT-g-14.json +src/open_clip/model_configs/coca_ViT-B-32.json +src/open_clip/model_configs/coca_ViT-L-14.json +src/open_clip/model_configs/coca_base.json +src/open_clip/model_configs/coca_roberta-ViT-B-32.json +src/open_clip/model_configs/convnext_base.json +src/open_clip/model_configs/convnext_base_w.json +src/open_clip/model_configs/convnext_base_w_320.json +src/open_clip/model_configs/convnext_large.json +src/open_clip/model_configs/convnext_large_d.json +src/open_clip/model_configs/convnext_large_d_320.json +src/open_clip/model_configs/convnext_small.json +src/open_clip/model_configs/convnext_tiny.json +src/open_clip/model_configs/convnext_xlarge.json +src/open_clip/model_configs/convnext_xxlarge.json +src/open_clip/model_configs/convnext_xxlarge_320.json +src/open_clip/model_configs/mt5-base-ViT-B-32.json +src/open_clip/model_configs/mt5-xl-ViT-H-14.json +src/open_clip/model_configs/roberta-ViT-B-32.json +src/open_clip/model_configs/swin_base_patch4_window7_224.json +src/open_clip/model_configs/vit_medium_patch16_gap_256.json +src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json +src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json +src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json +src/open_clip_torch.egg-info/PKG-INFO +src/open_clip_torch.egg-info/SOURCES.txt +src/open_clip_torch.egg-info/dependency_links.txt +src/open_clip_torch.egg-info/requires.txt +src/open_clip_torch.egg-info/top_level.txt +src/training/__init__.py +src/training/data.py +src/training/distributed.py +src/training/file_utils.py +src/training/imagenet_zeroshot_data.py +src/training/logger.py +src/training/main.py +src/training/params.py +src/training/precision.py +src/training/profile.py +src/training/scheduler.py +src/training/train.py +src/training/zero_shot.py +tests/test_download_pretrained.py +tests/test_hf_model.py +tests/test_inference.py +tests/test_inference_simple.py +tests/test_num_shards.py +tests/test_training_simple.py +tests/test_wds.py \ No newline at end of file diff --git a/openclip/open_clip_torch.egg-info/dependency_links.txt b/openclip/open_clip_torch.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/openclip/open_clip_torch.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/openclip/open_clip_torch.egg-info/requires.txt b/openclip/open_clip_torch.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..70684cdb8e8fd406aa77e0c731fcca8f2f0c2ffb --- /dev/null +++ b/openclip/open_clip_torch.egg-info/requires.txt @@ -0,0 +1,23 @@ +torch>=1.9.0 +torchvision +regex +ftfy +tqdm +huggingface_hub +sentencepiece +protobuf<4 +timm + +[training] +torch>=1.9.0 +torchvision +webdataset>=0.2.5 +regex +ftfy +tqdm +pandas +braceexpand +huggingface_hub +transformers +timm +fsspec diff --git a/openclip/open_clip_torch.egg-info/top_level.txt b/openclip/open_clip_torch.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..f8963f3cd18c8d3e9c49c462b48aeb4315e52cf4 --- /dev/null +++ b/openclip/open_clip_torch.egg-info/top_level.txt @@ -0,0 +1,2 @@ +open_clip +training diff --git a/openclip/training/.gitignore b/openclip/training/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..333c1e910a3e2bef1b9d0d4587392627d8388974 --- /dev/null +++ b/openclip/training/.gitignore @@ -0,0 +1 @@ +logs/ diff --git a/openclip/training/__init__.py b/openclip/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/openclip/training/__pycache__/__init__.cpython-310.pyc b/openclip/training/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6618d971bc1ee802d9ae1712d168c95598da7a72 Binary files /dev/null and b/openclip/training/__pycache__/__init__.cpython-310.pyc differ diff --git a/openclip/training/__pycache__/__init__.cpython-38.pyc b/openclip/training/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..510e89dbfda4d7848dd048c96f5affb54d77cc36 Binary files /dev/null and b/openclip/training/__pycache__/__init__.cpython-38.pyc differ diff --git a/openclip/training/__pycache__/data.cpython-310.pyc b/openclip/training/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebb5af0d69c93e6c9e722d86e2c81739dd470f57 Binary files /dev/null and b/openclip/training/__pycache__/data.cpython-310.pyc differ diff --git a/openclip/training/__pycache__/data.cpython-38.pyc b/openclip/training/__pycache__/data.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..781e3d290bf50f04f77e636b079325463a825065 Binary files /dev/null and b/openclip/training/__pycache__/data.cpython-38.pyc differ diff --git a/openclip/training/data.py b/openclip/training/data.py new file mode 100644 index 0000000000000000000000000000000000000000..af68bf6bc9cf7502eb83c2ec0ffaa15bc4b51af3 --- /dev/null +++ b/openclip/training/data.py @@ -0,0 +1,3205 @@ +import ast +import json +import logging +import math +import os +import random +import sys +import time +import braceexpand +from dataclasses import dataclass +from multiprocessing import Value +import cv2 + +import numpy as np +import pandas as pd +import torch +import torchvision.datasets as datasets +import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info +from torch.utils.data.distributed import DistributedSampler +from webdataset.filters import _shuffle +from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample +from torchvision import transforms + +import io +import PIL +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +import cv2 +import math +import json +import random +import seaborn as sns + +def vis_landmark_on_img(img, shape, linewidth=8): + ''' + Visualize landmark on images. + ''' + + def draw_curve(idx_list, color=(0, 255, 0), loop=False, lineWidth=linewidth): + for i in idx_list: + cv2.line(img, (shape[i][0], shape[i][1]), (shape[i + 1][0], shape[i + 1][1]), color, lineWidth) + if (loop): + cv2.line(img, (shape[idx_list[0]][0], shape[idx_list[0]][1]), + (shape[idx_list[-1] + 1][0], shape[idx_list[-1] + 1][1]), color, lineWidth) + + draw_curve(list(range(0, 16)), color=(255, 144, 25)) # jaw + draw_curve(list(range(17, 21)), color=(50, 205, 50)) # eye brow + draw_curve(list(range(22, 26)), color=(50, 205, 50)) + draw_curve(list(range(27, 35)), color=(208, 224, 63)) # nose + draw_curve(list(range(36, 41)), loop=True, color=(71, 99, 255)) # eyes + draw_curve(list(range(42, 47)), loop=True, color=(71, 99, 255)) + draw_curve(list(range(48, 59)), loop=True, color=(238, 130, 238)) # mouth + draw_curve(list(range(60, 67)), loop=True, color=(238, 130, 238)) + + return img.astype("uint8") + +def imshow_keypoints(img, + pose_result, + skeleton=None, + kpt_score_thr=0.3, + pose_kpt_color=None, + pose_link_color=None, + radius=4, + thickness=1, + show_keypoint_weight=False, + height=None, + width=None): + """Draw keypoints and links on an image. + + Args: + img (str or Tensor): The image to draw poses on. If an image array + is given, id will be modified in-place. + pose_result (list[kpts]): The poses to draw. Each element kpts is + a set of K keypoints as an Kx3 numpy.ndarray, where each + keypoint is represented as x, y, score. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, + the keypoint will not be drawn. + pose_link_color (np.array[Mx3]): Color of M links. If None, the + links will not be drawn. + thickness (int): Thickness of lines. + """ + + # img = mmcv.imread(img) + # img_h, img_w, _ = img.shape + if img is None: + img = np.zeros((height, width, 3), dtype=np.uint8) + img_h, img_w = height, width + else: + img_h, img_w, _ = img.shape + + for kpts in pose_result: + + kpts = np.array(kpts, copy=False) + + # draw each point on image + if pose_kpt_color is not None: + assert len(pose_kpt_color) == len(kpts) + for kid, kpt in enumerate(kpts): + x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] + if kpt_score > kpt_score_thr: + color = tuple(int(c) for c in pose_kpt_color[kid]) + if show_keypoint_weight: + img_copy = img.copy() + cv2.circle(img_copy, (int(x_coord), int(y_coord)), + radius, color, -1) + transparency = max(0, min(1, kpt_score)) + cv2.addWeighted( + img_copy, + transparency, + img, + 1 - transparency, + 0, + dst=img) + else: + cv2.circle(img, (int(x_coord), int(y_coord)), radius, + color, -1) + + # draw links + if skeleton is not None and pose_link_color is not None: + assert len(pose_link_color) == len(skeleton) + for sk_id, sk in enumerate(skeleton): + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + # if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 + # and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w + # and pos2[1] > 0 and pos2[1] < img_h + # and kpts[sk[0], 2] > kpt_score_thr + # and kpts[sk[1], 2] > kpt_score_thr): + if (kpts[sk[0], 2] > kpt_score_thr + and kpts[sk[1], 2] > kpt_score_thr): + color = tuple(int(c) for c in pose_link_color[sk_id]) + if show_keypoint_weight: + img_copy = img.copy() + X = (pos1[0], pos2[0]) + Y = (pos1[1], pos2[1]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 + angle = math.degrees( + math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = thickness + polygon = cv2.ellipse2Poly( + (int(mX), int(mY)), + (int(length / 2), int(stickwidth)), int(angle), 0, + 360, 1) + cv2.fillConvexPoly(img_copy, polygon, color) + # transparency = max( + # 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) + transparency = 1 + cv2.addWeighted( + img_copy, + transparency, + img, + 1 - transparency, + 0, + dst=img) + else: + cv2.line(img, pos1, pos2, color, thickness=thickness) + + return img + +def imshow_keypoints_body(img, + pose_result, + skeleton=None, + kpt_score_thr=0.3, + pose_kpt_color=None, + pose_link_color=None, + radius=4, + thickness=1, + show_keypoint_weight=False, + height=None, + width=None): + """Draw keypoints and links on an image. + + Args: + img (str or Tensor): The image to draw poses on. If an image array + is given, id will be modified in-place. + pose_result (list[kpts]): The poses to draw. Each element kpts is + a set of K keypoints as an Kx3 numpy.ndarray, where each + keypoint is represented as x, y, score. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, + the keypoint will not be drawn. + pose_link_color (np.array[Mx3]): Color of M links. If None, the + links will not be drawn. + thickness (int): Thickness of lines. + """ + + # img = mmcv.imread(img) + # img_h, img_w, _ = img.shape + if img is None: + img = np.zeros((height, width, 3), dtype=np.uint8) + img_h, img_w = height, width + else: + img_h, img_w, _ = img.shape + + for kpts in pose_result: + + kpts = np.array(kpts, copy=False) + + # draw each point on image + if pose_kpt_color is not None: + assert len(pose_kpt_color) == len(kpts) + for kid, kpt in enumerate(kpts): + if kid in [17, 18, 19, 20, 21, 22]: + continue + if kid in [13, 14, 15, 16]: + if kpt[0] > min(kpts[23:91, 0]) and kpt[0] < max(kpts[23:91, 0]) and kpt[1] > min(kpts[23:91, 1]) and kpt[1] < max(kpts[23:91, 1]): + continue + x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] + if kpt_score > kpt_score_thr: + color = tuple(int(c) for c in pose_kpt_color[kid]) + if show_keypoint_weight: + img_copy = img.copy() + cv2.circle(img_copy, (int(x_coord), int(y_coord)), + radius, color, -1) + transparency = max(0, min(1, kpt_score)) + cv2.addWeighted( + img_copy, + transparency, + img, + 1 - transparency, + 0, + dst=img) + else: + cv2.circle(img, (int(x_coord), int(y_coord)), radius, + color, -1) + + # draw links + if skeleton is not None and pose_link_color is not None: + assert len(pose_link_color) == len(skeleton) + for sk_id, sk in enumerate(skeleton): + if sk[0] in [17, 18, 19, 20, 21, 22] or sk[1] in [17, 18, 19, 20, 21, 22]: + continue + if sk[0] in [13, 14, 15, 16]: + if kpts[sk[0], 0] > min(kpts[23:91, 0]) and kpts[sk[0], 0] < max(kpts[23:91, 0]) and kpts[sk[0], 1] > min(kpts[23:91, 1]) and kpts[sk[0], 1] < max(kpts[23:91, 1]): + continue + if sk[1] in [13, 14, 15, 16]: + if kpts[sk[1], 0] > min(kpts[23:91, 0]) and kpts[sk[1], 0] < max(kpts[23:91, 0]) and kpts[sk[1], 1] > min(kpts[23:91, 1]) and kpts[sk[1], 1] < max(kpts[23:91, 1]): + continue + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + # if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 + # and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w + # and pos2[1] > 0 and pos2[1] < img_h + # and kpts[sk[0], 2] > kpt_score_thr + # and kpts[sk[1], 2] > kpt_score_thr): + if (kpts[sk[0], 2] > kpt_score_thr + and kpts[sk[1], 2] > kpt_score_thr): + color = tuple(int(c) for c in pose_link_color[sk_id]) + if show_keypoint_weight: + img_copy = img.copy() + X = (pos1[0], pos2[0]) + Y = (pos1[1], pos2[1]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 + angle = math.degrees( + math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = thickness + polygon = cv2.ellipse2Poly( + (int(mX), int(mY)), + (int(length / 2), int(stickwidth)), int(angle), 0, + 360, 1) + cv2.fillConvexPoly(img_copy, polygon, color) + # transparency = max( + # 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) + transparency = 1 + cv2.addWeighted( + img_copy, + transparency, + img, + 1 - transparency, + 0, + dst=img) + else: + cv2.line(img, pos1, pos2, color, thickness=thickness) + + return img + +def imshow_keypoints_whole(img, + pose_result, + skeleton=None, + kpt_score_thr=0.3, + pose_kpt_color=None, + pose_link_color=None, + radius=4, + thickness=1, + show_keypoint_weight=False, + height=None, + width=None): + """Draw keypoints and links on an image. + + Args: + img (str or Tensor): The image to draw poses on. If an image array + is given, id will be modified in-place. + pose_result (list[kpts]): The poses to draw. Each element kpts is + a set of K keypoints as an Kx3 numpy.ndarray, where each + keypoint is represented as x, y, score. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, + the keypoint will not be drawn. + pose_link_color (np.array[Mx3]): Color of M links. If None, the + links will not be drawn. + thickness (int): Thickness of lines. + """ + + # img = mmcv.imread(img) + # img_h, img_w, _ = img.shape + if img is None: + img = np.zeros((height, width, 3), dtype=np.uint8) + img_h, img_w = height, width + else: + img_h, img_w, _ = img.shape + + for kpts in pose_result: + + kpts = np.array(kpts, copy=False) + + # draw each point on image + if pose_kpt_color is not None: + assert len(pose_kpt_color) == len(kpts) + for kid, kpt in enumerate(kpts): + if kid in [17, 18, 19, 20, 21, 22]: + continue + if kid in [13, 14, 15, 16]: + if kpt[0] > min(kpts[23:91, 0]) and kpt[0] < max(kpts[23:91, 0]) and kpt[1] > min(kpts[23:91, 1]) and kpt[1] < max(kpts[23:91, 1]): + continue + x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] + if kpt_score > kpt_score_thr: + color = tuple(int(c) for c in pose_kpt_color[kid]) + if show_keypoint_weight: + img_copy = img.copy() + cv2.circle(img_copy, (int(x_coord), int(y_coord)), + radius, color, -1) + transparency = max(0, min(1, kpt_score)) + cv2.addWeighted( + img_copy, + transparency, + img, + 1 - transparency, + 0, + dst=img) + else: + cv2.circle(img, (int(x_coord), int(y_coord)), radius, + color, -1) + + # draw links + if skeleton is not None and pose_link_color is not None: + assert len(pose_link_color) == len(skeleton) + for sk_id, sk in enumerate(skeleton): + if sk[0] in [17, 18, 19, 20, 21, 22] or sk[1] in [17, 18, 19, 20, 21, 22]: + continue + if sk[0] in [13, 14, 15, 16]: + if kpts[sk[0], 0] > min(kpts[23:91, 0]) and kpts[sk[0], 0] < max(kpts[23:91, 0]) and kpts[sk[0], 1] > min(kpts[23:91, 1]) and kpts[sk[0], 1] < max(kpts[23:91, 1]): + continue + if sk[1] in [13, 14, 15, 16]: + if kpts[sk[1], 0] > min(kpts[23:91, 0]) and kpts[sk[1], 0] < max(kpts[23:91, 0]) and kpts[sk[1], 1] > min(kpts[23:91, 1]) and kpts[sk[1], 1] < max(kpts[23:91, 1]): + continue + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + # if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 + # and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w + # and pos2[1] > 0 and pos2[1] < img_h + # and kpts[sk[0], 2] > kpt_score_thr + # and kpts[sk[1], 2] > kpt_score_thr): + if (kpts[sk[0], 2] > kpt_score_thr + and kpts[sk[1], 2] > kpt_score_thr): + color = tuple(int(c) for c in pose_link_color[sk_id]) + if show_keypoint_weight: + img_copy = img.copy() + X = (pos1[0], pos2[0]) + Y = (pos1[1], pos2[1]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 + angle = math.degrees( + math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = thickness + polygon = cv2.ellipse2Poly( + (int(mX), int(mY)), + (int(length / 2), int(stickwidth)), int(angle), 0, + 360, 1) + cv2.fillConvexPoly(img_copy, polygon, color) + # transparency = max( + # 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) + transparency = 1 + cv2.addWeighted( + img_copy, + transparency, + img, + 1 - transparency, + 0, + dst=img) + else: + cv2.line(img, pos1, pos2, color, thickness=thickness) + + return img + +def draw_whole_body_skeleton( + img, + pose, + radius=4, + thickness=1, + kpt_score_thr=0.3, + height=None, + width=None, + ): + palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], + [230, 230, 0], [255, 153, 255], [153, 204, 255], + [255, 102, 255], [255, 51, 255], [102, 178, 255], + [51, 153, 255], [255, 153, 153], [255, 102, 102], + [255, 51, 51], [153, 255, 153], [102, 255, 102], + [51, 255, 51], [0, 255, 0], [0, 0, 255], + [255, 0, 0], [255, 255, 255]]) + + # below are for the whole body keypoints + skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], + [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], + [8, 10], [1, 2], [0, 1], [0, 2], + [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18], + [15, 19], [16, 20], [16, 21], [16, 22], [91, 92], + [92, 93], [93, 94], [94, 95], [91, 96], [96, 97], + [97, 98], [98, 99], [91, 100], [100, 101], [101, 102], + [102, 103], [91, 104], [104, 105], [105, 106], + [106, 107], [91, 108], [108, 109], [109, 110], + [110, 111], [112, 113], [113, 114], [114, 115], + [115, 116], [112, 117], [117, 118], [118, 119], + [119, 120], [112, 121], [121, 122], [122, 123], + [123, 124], [112, 125], [125, 126], [126, 127], + [127, 128], [112, 129], [129, 130], [130, 131], + [131, 132]] + + pose_link_color = palette[[ + 0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16 + ] + [16, 16, 16, 16, 16, 16] + [ + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, + 16 + ] + [ + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, + 16 + ]] + pose_kpt_color = palette[ + [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + + [0, 0, 0, 0, 0, 0] + [19] * (68 + 42)] + + draw = imshow_keypoints_whole(img, pose, skeleton, + kpt_score_thr=0.3, + pose_kpt_color=pose_kpt_color, + pose_link_color=pose_link_color, + radius=radius, + thickness=thickness, + show_keypoint_weight=True, + height=height, + width=width) + return draw + +def draw_humansd_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None, humansd_skeleton_width=10): + humansd_skeleton=[ + [0,0,1], + [1,0,2], + [2,1,3], + [3,2,4], + [4,3,5], + [5,4,6], + [6,5,7], + [7,6,8], + [8,7,9], + [9,8,10], + [10,5,11], + [11,6,12], + [12,11,13], + [13,12,14], + [14,13,15], + [15,14,16], + ] + # humansd_skeleton_width=10 + humansd_color=sns.color_palette("hls", len(humansd_skeleton)) + + def plot_kpts(img_draw, kpts, color, edgs,width): + for idx, kpta, kptb in edgs: + if kpts[kpta,2]>mmpose_detection_thresh and \ + kpts[kptb,2]>mmpose_detection_thresh : + line_color = tuple([int(255*color_i) for color_i in color[idx]]) + + cv2.line(img_draw, (int(kpts[kpta,0]),int(kpts[kpta,1])), (int(kpts[kptb,0]),int(kpts[kptb,1])), line_color,width) + cv2.circle(img_draw, (int(kpts[kpta,0]),int(kpts[kpta,1])), width//2, line_color, -1) + cv2.circle(img_draw, (int(kpts[kptb,0]),int(kpts[kptb,1])), width//2, line_color, -1) + + if image is None: + pose_image = np.zeros((height, width, 3), dtype=np.uint8) + else: + pose_image = np.array(image, dtype=np.uint8) + for person_i in range(len(pose)): + if np.sum(pose[person_i])>0: + plot_kpts(pose_image, pose[person_i],humansd_color,humansd_skeleton,humansd_skeleton_width) + + return pose_image + +def draw_controlnet_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None): + if image is None: + canvas = np.zeros((height, width, 3), dtype=np.uint8) + else: + H, W, C = image.shape + canvas = np.array(image, dtype=np.uint8) + + for pose_i in range(len(pose)): + present_pose=pose[pose_i] + candidate=[ + [present_pose[0,0],present_pose[0,1],present_pose[0,2],0], + [(present_pose[6,0]+present_pose[5,0])/2,(present_pose[6,1]+present_pose[5,1])/2,(present_pose[6,2]+present_pose[5,2])/2,1] if present_pose[6,2]>mmpose_detection_thresh and present_pose[5,2]>mmpose_detection_thresh else [-1,-1,0,1], + [present_pose[6,0],present_pose[6,1],present_pose[6,2],2], + [present_pose[8,0],present_pose[8,1],present_pose[8,2],3], + [present_pose[10,0],present_pose[10,1],present_pose[10,2],4], + [present_pose[5,0],present_pose[5,1],present_pose[5,2],5], + [present_pose[7,0],present_pose[7,1],present_pose[7,2],6], + [present_pose[9,0],present_pose[9,1],present_pose[9,2],7], + [present_pose[12,0],present_pose[12,1],present_pose[12,2],8], + [present_pose[14,0],present_pose[14,1],present_pose[14,2],9], + [present_pose[16,0],present_pose[16,1],present_pose[16,2],10], + [present_pose[11,0],present_pose[11,1],present_pose[11,2],11], + [present_pose[13,0],present_pose[13,1],present_pose[13,2],12], + [present_pose[15,0],present_pose[15,1],present_pose[15,2],13], + [present_pose[2,0],present_pose[2,1],present_pose[2,2],14], + [present_pose[1,0],present_pose[1,1],present_pose[1,2],15], + [present_pose[4,0],present_pose[4,1],present_pose[4,2],16], + [present_pose[3,0],present_pose[3,1],present_pose[3,2],17], + ] + stickwidth = 4 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + if candidate[limbSeq[i][0]-1][2]>mmpose_detection_thresh and candidate[limbSeq[i][1]-1][2]>mmpose_detection_thresh: + Y=[candidate[limbSeq[i][1]-1][0],candidate[limbSeq[i][0]-1][0]] + X=[candidate[limbSeq[i][1]-1][1],candidate[limbSeq[i][0]-1][1]] + + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cur_canvas = canvas.copy() + cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) + canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) + + for i in range(18): + if candidate[i][2]>mmpose_detection_thresh: + x, y = candidate[i][0:2] + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + +def draw_body_skeleton( + img, + pose, + radius=4, + thickness=1, + kpt_score_thr=0.3, + height=None, + width=None, + ): + palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], + [230, 230, 0], [255, 153, 255], [153, 204, 255], + [255, 102, 255], [255, 51, 255], [102, 178, 255], + [51, 153, 255], [255, 153, 153], [255, 102, 102], + [255, 51, 51], [153, 255, 153], [102, 255, 102], + [51, 255, 51], [0, 255, 0], [0, 0, 255], + [255, 0, 0], [255, 255, 255]]) + + # below are for the body keypoints + # skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], + # [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], + # [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], + # [3, 5], [4, 6]] + skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], + [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], + [8, 10], [3, 4], + [3, 5], [4, 6]] + + # pose_link_color = palette[[ + # 0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16 + # ]] + # pose_kpt_color = palette[[ + # 16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0 + # ]] + pose_link_color = palette[[ + 12, 16, 1, 5, 9, 13, 19, 15, 11, 7, 3, 18, 14, 8, 0 + ]] + pose_kpt_color = palette[[ + 19, 15, 11, 7, 3, 18, 14, 10, 6, 2, 17, 13, 9, 5, 1, 16, 12 + ]] + draw = imshow_keypoints_body(img, pose, skeleton, + kpt_score_thr=0.3, + pose_kpt_color=pose_kpt_color, + pose_link_color=pose_link_color, + radius=radius, + thickness=thickness, + show_keypoint_weight=True, + height=height, + width=width) + return draw + +def draw_face_skeleton( + img, + pose, + radius=4, + thickness=1, + kpt_score_thr=0.3, + height=None, + width=None, + ): + palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], + [230, 230, 0], [255, 153, 255], [153, 204, 255], + [255, 102, 255], [255, 51, 255], [102, 178, 255], + [51, 153, 255], [255, 153, 153], [255, 102, 102], + [255, 51, 51], [153, 255, 153], [102, 255, 102], + [51, 255, 51], [0, 255, 0], [0, 0, 255], + [255, 0, 0], [255, 255, 255]]) + + # below are for the face keypoints + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 68] + kpt_score_thr = 0 + + draw = imshow_keypoints(img, pose, skeleton, + kpt_score_thr=kpt_score_thr, + pose_kpt_color=pose_kpt_color, + pose_link_color=pose_link_color, + radius=radius, + thickness=thickness, + show_keypoint_weight=True, + height=height, + width=width) + return draw + +def draw_hand_skeleton( + img, + pose, + radius=4, + thickness=1, + kpt_score_thr=0.3, + height=None, + width=None, + ): + palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], + [230, 230, 0], [255, 153, 255], [153, 204, 255], + [255, 102, 255], [255, 51, 255], [102, 178, 255], + [51, 153, 255], [255, 153, 153], [255, 102, 102], + [255, 51, 51], [153, 255, 153], [102, 255, 102], + [51, 255, 51], [0, 255, 0], [0, 0, 255], + [255, 0, 0], [255, 255, 255]]) + + # hand option 1 + skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], + [7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13], + [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], + [18, 19], [19, 20]] + + pose_link_color = palette[[ + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, + 16 + ]] + pose_kpt_color = palette[[ + 0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, + 16, 16 + ]] + + # # hand option 2 + # skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9], + # [9, 10], [10, 11], [12, 13], [13, 14], [14, 15], + # [16, 17], [17, 18], [18, 19], [3, 20], [7, 20], + # [11, 20], [15, 20], [19, 20]] + + # pose_link_color = palette[[ + # 0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, + # 16 + # ]] + # pose_kpt_color = palette[[ + # 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, + # 16, 0 + # ]] + + draw = imshow_keypoints(img, pose, skeleton, + kpt_score_thr=kpt_score_thr, + pose_kpt_color=pose_kpt_color, + pose_link_color=pose_link_color, + radius=radius, + thickness=thickness, + show_keypoint_weight=True, + height=height, + width=width) + return draw + + +class CsvDataset(Dataset): + def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None): + logging.debug(f'Loading csv data from {input_filename}.') + df = pd.read_csv(input_filename, sep=sep) + + self.images = df[img_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug('Done loading data.') + + self.tokenize = tokenizer + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + images = self.transforms(Image.open(str(self.images[idx]))) + texts = self.tokenize([str(self.captions[idx])])[0] + return images, texts + + +class SharedEpoch: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + def set_value(self, epoch): + self.shared_epoch.value = epoch + + def get_value(self): + return self.shared_epoch.value + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler = None + shared_epoch: SharedEpoch = None + + def set_epoch(self, epoch): + if self.shared_epoch is not None: + self.shared_epoch.set_value(epoch) + if self.sampler is not None and isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + + +def expand_urls(urls, weights=None): + if weights is None: + expanded_urls = wds.shardlists.expand_urls(urls) + return expanded_urls, None + if isinstance(urls, str): + urllist = urls.split("::") + weights = weights.split('::') + assert len(weights) == len(urllist), f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match." + weights = [float(weight) for weight in weights] + all_urls, all_weights = [], [] + for url, weight in zip(urllist, weights): + expanded_url = list(braceexpand.braceexpand(url)) + expanded_weights = [weight for _ in expanded_url] + all_urls.extend(expanded_url) + all_weights.extend(expanded_weights) + return all_urls, all_weights + else: + all_urls = list(urls) + return all_urls, weights + + +def get_dataset_size(shards): + shards_list, _ = expand_urls(shards) + dir_path = os.path.dirname(shards_list[0]) + sizes_filename = os.path.join(dir_path, 'sizes.json') + len_filename = os.path.join(dir_path, '__len__') + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, 'r')) + total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list]) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, 'r').read()) + else: + total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # CC3M (train): 2905954 + # CC12M: 10968539 + # LAION-400M: 407332084 + # LAION-2B (english): 2170337258 + num_shards = len(shards_list) + return total_size, num_shards + + +def get_imagenet(args, preprocess_fns, split): + assert split in ["train", "val", "v2"] + is_train = split == "train" + preprocess_train, preprocess_val = preprocess_fns + + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + if is_train: + data_path = args.imagenet_train + preprocess_fn = preprocess_train + else: + data_path = args.imagenet_val + preprocess_fn = preprocess_val + assert data_path + + dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) + + if is_train: + idxs = np.zeros(len(dataset.targets)) + target_array = np.array(dataset.targets) + k = 50 + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:k] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype('int') + sampler = SubsetRandomSampler(np.where(idxs)[0]) + else: + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=sampler, + ) + + return DataInfo(dataloader=dataloader, sampler=sampler) + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption_or_no_image(sample): + has_caption = ('txt' in sample) + has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) + return has_caption and has_image + + +def filter_no_image_or_no_ldmk(sample): + has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) + has_ldmk = ('ldmk' in sample) + return has_image and has_ldmk + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, issue a warning, and continue.""" + logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') + return True + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=log_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +def pytorch_worker_seed(increment=0): + """get dataloader worker seed from pytorch""" + worker_info = get_worker_info() + if worker_info is not None: + # favour using the seed already created for pytorch dataloader workers if it exists + seed = worker_info.seed + if increment: + # space out seed increments so they can't overlap across workers in different iterations + seed += increment * max(1, worker_info.num_workers) + return seed + # fallback to wds rank based seed + return wds.utils.pytorch_worker_seed() + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +class detshuffle2(wds.PipelineStage): + def __init__( + self, + bufsize=1000, + initial=100, + seed=0, + epoch=-1, + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + rng = random.Random() + if self.seed < 0: + # If seed is negative, we use the worker's seed, this will be different across all nodes/workers + seed = pytorch_worker_seed(epoch) + else: + # This seed to be deterministic AND the same across all nodes/workers in each epoch + seed = self.seed + epoch + rng.seed(seed) + return _shuffle(src, self.bufsize, self.initial, rng) + + +class ResampledShards2(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + weights=None, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + epoch=-1, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls, weights = expand_urls(urls, weights) + self.urls = urls + self.weights = weights + if self.weights is not None: + assert len(self.urls) == len(self.weights), f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match." + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.rng = random.Random() + self.worker_seed = worker_seed + self.deterministic = deterministic + self.epoch = epoch + + def __iter__(self): + """Return an iterator over the shards.""" + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + if self.deterministic: + # reset seed w/ epoch if deterministic + if self.worker_seed is None: + # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id + seed = pytorch_worker_seed(epoch) + else: + seed = self.worker_seed() + epoch + self.rng.seed(seed) + for _ in range(self.nshards): + if self.weights is None: + yield dict(url=self.rng.choice(self.urls)) + else: + yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0]) + +def get_wds_dataset_filter(args, preprocess_img): + input_shards = args.train_data + assert input_shards is not None + + pipeline = [wds.SimpleShardList(input_shards)] + + def replicate_img(sample): + import copy + sample["original"] = copy.copy(sample["image"]) + return sample + + def decode_byte_to_rgb(sample): + # import io + # import PIL + # from PIL import ImageFile + # ImageFile.LOAD_TRUNCATED_IMAGES = True + with io.BytesIO(sample["image"]) as stream: + try: + img = PIL.Image.open(stream) + img.load() + img = img.convert("RGB") + sample["image"] = img + return sample + except: + print("A broken image is encountered, replace w/ a placeholder") + image = Image.new('RGB', (512, 512)) + sample["image"] = image + return sample + + # at this point we have an iterator over all the shards + pipeline.extend([ + wds.split_by_node, + wds.split_by_worker, + tarfile_to_samples_nothrow, + wds.select(filter_no_caption_or_no_image), + # wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="txt"), + # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + wds.map(replicate_img), + wds.map(decode_byte_to_rgb), + # wds.map_dict(image=preprocess_img, text=lambda x: x.encode('utf-8'), \ + # __key__=lambda x: x.encode('utf-8'), __url__=lambda x: x.encode('utf-8')), + wds.map_dict(image=preprocess_img), + wds.to_tuple("original", "image", "text", "__key__", "__url__", "json"), + wds.batched(args.batch_size, partial=True) + ]) + + dataset = wds.DataPipeline(*pipeline) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + drop_last=False + ) + + return DataInfo(dataloader=dataloader) + +def get_wds_dataset_cond_face(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + 'Currently, number of dataset samples must be specified for training dataset. ' + 'Please specify via `--train-num-samples` if no dataset length info present.') + else: + num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] + else: + assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + + def preprocess_image(sample): + # print(main_args.resolution, main_args.center_crop, main_args.random_flip) + resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) + sample["image"] = resize_transform(sample["image"]) + sample["ldmk"] = resize_transform(sample["ldmk"]) + transform_list = [] + image_height, image_width = sample["image"].height, sample["image"].width + i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() + j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() + + if main_args.center_crop or not is_train: + transform_list.append(transforms.CenterCrop(main_args.resolution)) + else: + if image_height < main_args.resolution or image_width < main_args.resolution: + raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") + + elif image_width == main_args.resolution and image_height == main_args.resolution: + i, j = 0, 0 + + transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) + + if is_train and torch.rand(1) < 0.5: + transform_list.append(transforms.RandomHorizontalFlip(p=1.)) + transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + train_transforms = transforms.Compose(transform_list) + sample["image"] = train_transforms(sample["image"]) + sample["ldmk"] = train_transforms(sample["ldmk"]) + return sample + + # def extract_ldmk(sample): + # image_height, image_width = sample["image"].height, sample["image"].width + # preds = fa.get_landmarks(np.array(sample["image"])) + # lands = [] + # if preds is not None: + # for pred in preds: + # land = pred.reshape(-1, 3)[:,:2].astype(int) + # lands.append(land) + + # lms_color_map = np.zeros(shape=(image_height, image_width, 3)).astype("uint8") + # if len(lands) > 0: + # for land in lands: + # lms_color_map = vis_landmark_on_img(lms_color_map, land) + # # print(lms_color_map.shape) + # sample["ldmk"] = Image.fromarray(lms_color_map) + # return sample + + def visualize_ldmk(sample): + image_height, image_width = sample["image"].height, sample["image"].width + lands = np.frombuffer(sample["ldmk"], dtype=np.float32) + lms_color_map = np.zeros(shape=(image_height, image_width, 3)).astype("uint8") + if len(lands) > 0: + lands = lands.reshape(-1, 68, 3).astype(int) + for i in range(lands.shape[0]): + lms_color_map = vis_landmark_on_img(lms_color_map, lands[i]) + # print(lms_color_map.shape) + sample["ldmk"] = Image.fromarray(lms_color_map) + return sample + + def filter_ldmk_none(sample): + return not (sample["ldmk"] == -1).all() + + def filter_low_res(sample): + if filter_lowres: + string_json = sample["json"].decode('utf-8') + dict_json = json.loads(string_json) + if "height" in dict_json.keys() and "width" in dict_json.keys(): + min_length = min(dict_json["height"], dict_json["width"]) + return min_length >= main_args.resolution + else: + return True + return True + + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + wds.select(filter_low_res), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="txt"), + # wds.map(extract_ldmk), + wds.map(visualize_ldmk), + wds.map(preprocess_image), + wds.select(filter_ldmk_none), + # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + wds.map_dict( + text=lambda text: tokenizer(text, \ + max_length=tokenizer.model_max_length, \ + padding="max_length", truncation=True, \ + return_tensors='pt')['input_ids'], + ), + wds.to_tuple("image", "text", "ldmk"), + # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + +def get_wds_dataset_depth(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False, filter_mface=False, filter_wpose=False): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + 'Currently, number of dataset samples must be specified for training dataset. ' + 'Please specify via `--train-num-samples` if no dataset length info present.') + else: + num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] + else: + assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + + def decode_image(sample): + + with io.BytesIO(sample["omni_depth"]) as stream: + try: + img = PIL.Image.open(stream) + img.load() + img = img.convert("RGB") + sample["depth"] = img + except: + print("A broken image is encountered, replace w/ a placeholder") + image = Image.new('RGB', (512, 512)) + sample["depth"] = image + + return sample + + train_transforms = transforms.Compose( + [ + transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(main_args.resolution) if main_args.center_crop else transforms.RandomCrop(main_args.resolution), + transforms.RandomHorizontalFlip() if main_args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def filter_depth_none(sample): + return not (sample["depth"] == -1).all() + + def filter_low_res(sample): + if filter_lowres: + string_json = sample["json"].decode('utf-8') + dict_json = json.loads(string_json) + if "height" in dict_json.keys() and "width" in dict_json.keys(): + min_length = min(dict_json["height"], dict_json["width"]) + return min_length >= main_args.resolution + else: + return True + return True + + def filter_multi_face(sample): + if filter_mface: + face_kp = np.frombuffer(sample["face_kp"], dtype=np.float32).reshape(-1, 98, 2) + if face_kp.shape[0] > 1: + return False + + return True + + def filter_whole_skeleton(sample): + if filter_wpose: + height, width = sample["image"].height, sample["image"].width + body_kp = np.frombuffer(sample["body_kp"], dtype=np.float32).reshape(17, 2) + if (body_kp[:, 0] > 0).all() and (body_kp[:, 0] < width).all() and (body_kp[:, 1] > 0).all() and (body_kp[:, 1] < height).all(): + return True + else: + return False + + return True + + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + wds.select(filter_multi_face), + wds.select(filter_low_res), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="txt"), + wds.select(filter_whole_skeleton), + wds.map(decode_image), + wds.map_dict(depth=train_transforms), + wds.select(filter_depth_none), + # wds.map_dict(depth=train_transforms, text=lambda text: tokenizer(text)[0]), + wds.map_dict( + text=lambda text: tokenizer(text, \ + max_length=tokenizer.model_max_length, \ + padding="max_length", truncation=True, \ + return_tensors='pt')['input_ids']), + wds.to_tuple("depth", "text"), + # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + +def get_wds_dataset_depth2canny(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + 'Currently, number of dataset samples must be specified for training dataset. ' + 'Please specify via `--train-num-samples` if no dataset length info present.') + else: + num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] + else: + assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + + def decode_image(sample): + with io.BytesIO(sample["omni_depth"]) as stream: + try: + img = PIL.Image.open(stream) + img.load() + img = img.convert("RGB") + sample["depth"] = img + except: + print("A broken image is encountered, replace w/ a placeholder") + image = Image.new('RGB', (512, 512)) + sample["depth"] = image + + return sample + + def add_canny(sample): + canny = np.array(sample["image"]) + + low_threshold = 100 + high_threshold = 200 + + canny = cv2.Canny(canny, low_threshold, high_threshold) + canny = canny[:, :, None] + canny = np.concatenate([canny, canny, canny], axis=2) + sample["canny"] = Image.fromarray(canny) + return sample + + def preprocess_image(sample): + # print(main_args.resolution, main_args.center_crop, main_args.random_flip) + if grid_dnc: + resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) + else: + resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) + sample["image"] = resize_transform(sample["image"]) + sample["canny"] = resize_transform(sample["canny"]) + sample["depth"] = resize_transform(sample["depth"]) + transform_list = [] + image_height, image_width = sample["image"].height, sample["image"].width + if grid_dnc: + i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() + j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() + else: + i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() + j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() + + if main_args.center_crop or not is_train: + if grid_dnc: + transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) + else: + transform_list.append(transforms.CenterCrop(main_args.resolution)) + else: + if grid_dnc: + if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: + raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") + + elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: + i, j = 0, 0 + + transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) + else: + if image_height < main_args.resolution or image_width < main_args.resolution: + raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") + + elif image_width == main_args.resolution and image_height == main_args.resolution: + i, j = 0, 0 + + transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) + + if is_train and torch.rand(1) < 0.5: + transform_list.append(transforms.RandomHorizontalFlip(p=1.)) + transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + train_transforms = transforms.Compose(transform_list) + sample["image"] = train_transforms(sample["image"]) + sample["canny"] = train_transforms(sample["canny"]) + sample["depth"] = train_transforms(sample["depth"]) + return sample + + def random_mask(sample): + + if is_train and dropout: + random_num = torch.rand(1) + + if random_num < 0.1: + sample["depth"] = torch.ones_like(sample["depth"]) * (-1) + + return sample + + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="txt"), + wds.map(decode_image), + wds.map(add_canny), + wds.map(preprocess_image), + wds.map(random_mask), + # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + wds.map_dict( + text=lambda text: tokenizer(text, \ + max_length=tokenizer.model_max_length, \ + padding="max_length", truncation=True, \ + return_tensors='pt')['input_ids'], + ), + wds.to_tuple("canny", "text", "depth"), + # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + +def get_wds_dataset_depth2normal(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + 'Currently, number of dataset samples must be specified for training dataset. ' + 'Please specify via `--train-num-samples` if no dataset length info present.') + else: + num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] + else: + assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + + def decode_image(sample): + with io.BytesIO(sample["omni_normal"]) as stream: + try: + img = PIL.Image.open(stream) + img.load() + img = img.convert("RGB") + sample["normal"] = img + except: + print("A broken image is encountered, replace w/ a placeholder") + image = Image.new('RGB', (512, 512)) + sample["normal"] = image + + with io.BytesIO(sample["omni_depth"]) as stream: + try: + img = PIL.Image.open(stream) + img.load() + img = img.convert("RGB") + sample["depth"] = img + except: + print("A broken image is encountered, replace w/ a placeholder") + image = Image.new('RGB', (512, 512)) + sample["depth"] = image + + return sample + + def preprocess_image(sample): + # print(main_args.resolution, main_args.center_crop, main_args.random_flip) + if grid_dnc: + resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) + else: + resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) + sample["image"] = resize_transform(sample["image"]) + sample["normal"] = resize_transform(sample["normal"]) + sample["depth"] = resize_transform(sample["depth"]) + transform_list = [] + image_height, image_width = sample["image"].height, sample["image"].width + if grid_dnc: + i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() + j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() + else: + i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() + j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() + + if main_args.center_crop or not is_train: + if grid_dnc: + transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) + else: + transform_list.append(transforms.CenterCrop(main_args.resolution)) + else: + if grid_dnc: + if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: + raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") + + elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: + i, j = 0, 0 + + transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) + else: + if image_height < main_args.resolution or image_width < main_args.resolution: + raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") + + elif image_width == main_args.resolution and image_height == main_args.resolution: + i, j = 0, 0 + + transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) + + if is_train and torch.rand(1) < 0.5: + transform_list.append(transforms.RandomHorizontalFlip(p=1.)) + transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + train_transforms = transforms.Compose(transform_list) + sample["image"] = train_transforms(sample["image"]) + sample["normal"] = train_transforms(sample["normal"]) + sample["depth"] = train_transforms(sample["depth"]) + return sample + + def random_mask(sample): + + if is_train and dropout: + random_num = torch.rand(1) + + if random_num < 0.1: + sample["depth"] = torch.ones_like(sample["depth"]) * (-1) + + return sample + + def filter_low_res(sample): + if filter_lowres: + string_json = sample["json"].decode('utf-8') + dict_json = json.loads(string_json) + if "height" in dict_json.keys() and "width" in dict_json.keys(): + min_length = min(dict_json["height"], dict_json["width"]) + return min_length >= main_args.resolution + else: + return True + return True + + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + wds.select(filter_low_res), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="txt"), + wds.map(decode_image), + wds.map(preprocess_image), + wds.map(random_mask), + # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + wds.map_dict( + text=lambda text: tokenizer(text, \ + max_length=tokenizer.model_max_length, \ + padding="max_length", truncation=True, \ + return_tensors='pt')['input_ids'], + ), + wds.to_tuple("normal", "text", "depth"), + # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + +# def get_wds_dataset_cond_sdxl(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): +# input_shards = args.train_data if is_train else args.val_data +# assert input_shards is not None +# resampled = getattr(args, 'dataset_resampled', False) and is_train + +# num_samples, num_shards = get_dataset_size(input_shards) +# if not num_samples: +# if is_train: +# num_samples = args.train_num_samples +# if not num_samples: +# raise RuntimeError( +# 'Currently, number of dataset samples must be specified for training dataset. ' +# 'Please specify via `--train-num-samples` if no dataset length info present.') +# else: +# num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + +# shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + +# if resampled: +# pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] +# else: +# assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." +# pipeline = [wds.SimpleShardList(input_shards)] + +# # at this point we have an iterator over all the shards +# if is_train: +# if not resampled: +# pipeline.extend([ +# detshuffle2( +# bufsize=_SHARD_SHUFFLE_SIZE, +# initial=_SHARD_SHUFFLE_INITIAL, +# seed=args.seed, +# epoch=shared_epoch, +# ), +# wds.split_by_node, +# wds.split_by_worker, +# ]) +# pipeline.extend([ +# # at this point, we have an iterator over the shards assigned to each worker at each node +# tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), +# wds.shuffle( +# bufsize=_SAMPLE_SHUFFLE_SIZE, +# initial=_SAMPLE_SHUFFLE_INITIAL, +# ), +# ]) +# else: +# pipeline.extend([ +# wds.split_by_worker, +# # at this point, we have an iterator over the shards assigned to each worker +# wds.tarfile_to_samples(handler=log_and_continue), +# ]) + +# def pose2img(sample): +# height, width = sample["image"].height, sample["image"].width +# min_length = min(height, width) +# radius_body = max(int(4. * min_length / main_args.resolution), 4) +# thickness_body = max(int(2. * min_length / main_args.resolution), 2) +# radius_face = max(int(2. * min_length / main_args.resolution), 2) +# thickness_face = max(int(1. * min_length / main_args.resolution), 1) +# radius_hand = max(int(2. * min_length / main_args.resolution), 2) +# thickness_hand = max(int(1. * min_length / main_args.resolution), 1) +# # if "getty" in sample["__url__"]: +# # radius_body *= 4 +# # thickness_body *= 4 +# # radius_face *= 4 +# # thickness_face *= 4 +# # radius_hand *= 4 +# # thickness_hand *= 4 +# body_kp = np.frombuffer(sample["body_kp"], dtype=np.float32).reshape(17, 2) +# body_kpconf = np.frombuffer(sample["body_kpconf"], dtype=np.float32) +# body_all = np.concatenate([body_kp, body_kpconf[:, np.newaxis]], axis=1) +# body_all = body_all[np.newaxis, ...] +# body_draw = draw_body_skeleton( +# img=None, +# pose=body_all, +# radius=radius_body, +# thickness=thickness_body, +# height=height, +# width=width +# ) +# body_draw = Image.fromarray(body_draw) + +# face_kp = np.frombuffer(sample["face_kp"], dtype=np.float32).reshape(-1, 98, 2) +# face_kpconf = np.frombuffer(sample["face_kpconf"], dtype=np.float32).reshape(-1, 98) +# face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) + +# face_draw = draw_face_skeleton( +# # img=np.array(img), +# img=None, +# pose=face_all, +# radius=radius_face, +# thickness=thickness_face, +# height=height, +# width=width +# ) +# face_draw = Image.fromarray(face_draw) + +# hand_kp = np.frombuffer(sample["hand_kp"], dtype=np.float32).reshape(-1, 21, 2) +# hand_kpconf = np.frombuffer(sample["hand_kpconf"], dtype=np.float32).reshape(-1, 21) +# hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) + +# hand_draw = draw_hand_skeleton( +# # img=np.array(img), +# img=None, +# pose=hand_all, +# radius=radius_hand, +# thickness=thickness_hand, +# height=height, +# width=width +# ) +# hand_draw = Image.fromarray(hand_draw) + +# sample["body"] = body_draw +# sample["face"] = face_draw +# sample["hand"] = hand_draw +# return sample + +# def decode_image(sample): +# with io.BytesIO(sample["omni_normal"]) as stream: +# try: +# img = PIL.Image.open(stream) +# img.load() +# img = img.convert("RGB") +# sample["normal"] = img +# except: +# print("A broken image is encountered, replace w/ a placeholder") +# image = Image.new('RGB', (512, 512)) +# sample["normal"] = image + +# with io.BytesIO(sample["omni_depth"]) as stream: +# try: +# img = PIL.Image.open(stream) +# img.load() +# img = img.convert("RGB") +# sample["depth"] = img +# except: +# print("A broken image is encountered, replace w/ a placeholder") +# image = Image.new('RGB', (512, 512)) +# sample["depth"] = image + +# return sample + +# def add_canny(sample): +# canny = np.array(sample["image"]) + +# low_threshold = 100 +# high_threshold = 200 + +# canny = cv2.Canny(canny, low_threshold, high_threshold) +# canny = canny[:, :, None] +# canny = np.concatenate([canny, canny, canny], axis=2) +# sample["canny"] = Image.fromarray(canny) +# return sample + +# def decode_text(sample): +# sample["blip"] = sample["blip"].decode("utf-8") +# sample["blip_raw"] = sample["blip"] +# sample["text_raw"] = sample["text"] +# return sample + +# def augment_text(sample): +# if is_train and string_concat: +# sample["text"] = sample["text"] + " " + sample["blip"] +# if is_train and string_substitute: +# sample["text"] = sample["blip"] +# return sample + +# def preprocess_image(sample): +# # print(main_args.resolution, main_args.center_crop, main_args.random_flip) +# if grid_dnc: +# resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) +# else: +# resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) +# sample["image"] = resize_transform(sample["image"]) +# sample["normal"] = resize_transform(sample["normal"]) +# sample["depth"] = resize_transform(sample["depth"]) +# sample["canny"] = resize_transform(sample["canny"]) +# sample["body"] = resize_transform(sample["body"]) +# sample["face"] = resize_transform(sample["face"]) +# sample["hand"] = resize_transform(sample["hand"]) +# transform_list = [] +# image_height, image_width = sample["image"].height, sample["image"].width +# if grid_dnc: +# i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() +# j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() +# else: +# i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() +# j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() + +# if main_args.center_crop or not is_train: +# sample["description"]["crop_tl_h"] = (image_height - main_args.resolution) // 2 +# sample["description"]["crop_tl_w"] = (image_width - main_args.resolution) // 2 +# if grid_dnc: +# transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) +# else: +# transform_list.append(transforms.CenterCrop(main_args.resolution)) +# else: +# if grid_dnc: +# if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: +# raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") + +# elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: +# i, j = 0, 0 + +# transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) +# else: +# if image_height < main_args.resolution or image_width < main_args.resolution: +# raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") + +# elif image_width == main_args.resolution and image_height == main_args.resolution: +# i, j = 0, 0 + +# transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) +# sample["description"]["crop_tl_h"] = i +# sample["description"]["crop_tl_w"] = j + +# if is_train and torch.rand(1) < 0.5: +# transform_list.append(transforms.RandomHorizontalFlip(p=1.)) +# transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) +# train_transforms = transforms.Compose(transform_list) +# sample["image"] = train_transforms(sample["image"]) +# sample["normal"] = train_transforms(sample["normal"]) +# sample["depth"] = train_transforms(sample["depth"]) +# sample["canny"] = train_transforms(sample["canny"]) +# sample["body"] = train_transforms(sample["body"]) +# sample["face"] = train_transforms(sample["face"]) +# sample["hand"] = train_transforms(sample["hand"]) +# return sample + +# def random_mask(sample): + +# if is_train and dropout: +# random_num = torch.rand(1) + +# if random_num < 0.1: +# sample["normal"] = torch.ones_like(sample["normal"]) * (-1) +# sample["depth"] = torch.ones_like(sample["depth"]) * (-1) +# sample["canny"] = torch.ones_like(sample["canny"]) * (-1) +# sample["body"] = torch.ones_like(sample["body"]) * (-1) +# sample["face"] = torch.ones_like(sample["face"]) * (-1) +# sample["hand"] = torch.ones_like(sample["hand"]) * (-1) +# elif random_num > 0.9: +# pass +# else: +# if torch.rand(1) < 0.5: +# sample["normal"] = torch.ones_like(sample["normal"]) * (-1) +# if torch.rand(1) < 0.5: +# sample["depth"] = torch.ones_like(sample["depth"]) * (-1) +# if torch.rand(1) < 0.8: +# sample["canny"] = torch.ones_like(sample["canny"]) * (-1) +# if torch.rand(1) < 0.5: +# sample["body"] = torch.ones_like(sample["body"]) * (-1) +# if torch.rand(1) < 0.5: +# sample["face"] = torch.ones_like(sample["face"]) * (-1) +# if torch.rand(1) < 0.2: +# sample["hand"] = torch.ones_like(sample["hand"]) * (-1) + +# return sample + +# def make_grid_dnc(sample): + +# if grid_dnc: +# resized_image = transforms.functional.resize(sample["image"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) +# resized_depth = transforms.functional.resize(sample["depth"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) +# resized_normal = transforms.functional.resize(sample["normal"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) +# resized_canny = transforms.functional.resize(sample["canny"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) +# grid = torch.cat([torch.cat([resized_image, resized_depth], dim=2), +# torch.cat([resized_normal, resized_canny], dim=2)], dim=1) +# assert grid.shape[1] == main_args.resolution and grid.shape[2] == main_args.resolution +# sample["image"] = grid +# return sample + +# def filter_low_res(sample): +# if main_args.filter_res is None: +# main_args.filter_res = main_args.resolution +# if filter_lowres: +# string_json = sample["json"].decode('utf-8') +# dict_json = json.loads(string_json) +# if "height" in dict_json.keys() and "width" in dict_json.keys(): +# min_length = min(dict_json["height"], dict_json["width"]) +# return min_length >= main_args.filter_res +# else: +# return True +# return True + +# def add_original_hw(sample): +# image_height, image_width = sample["image"].height, sample["image"].width +# sample["description"] = {"h": image_height, "w": image_width} +# return sample + +# def add_description(sample): +# # string_json = sample["json"].decode('utf-8') +# # dict_json = json.loads(string_json) +# dict_json = sample["json"] +# if "height" in dict_json.keys() and "width" in dict_json.keys(): +# sample["description"]["h"] = dict_json["height"] +# sample["description"]["w"] = dict_json["width"] + +# return sample + +# pipeline.extend([ +# wds.select(filter_no_caption_or_no_image), +# wds.select(filter_low_res), +# wds.decode("pilrgb", handler=log_and_continue), +# wds.rename(image="jpg;png;jpeg;webp", text="txt"), +# wds.map(add_original_hw), +# wds.map(decode_text), +# wds.map(augment_text), +# wds.map(pose2img), +# wds.map(decode_image), +# wds.map(add_canny), +# wds.map(preprocess_image), +# wds.map(make_grid_dnc), +# wds.map(random_mask), +# wds.map(add_description), +# # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), +# # wds.map_dict( +# # text=lambda text: tokenizer(text, \ +# # max_length=tokenizer.model_max_length, \ +# # padding="max_length", truncation=True, \ +# # return_tensors='pt')['input_ids'], +# # blip=lambda blip: tokenizer(blip, \ +# # max_length=tokenizer.model_max_length, \ +# # padding="max_length", truncation=True, \ +# # return_tensors='pt')['input_ids'] +# # ), +# wds.to_tuple("image", "text", "text_raw", "blip", "blip_raw", "body", "face", "hand", "normal", "depth", "canny", "description"), +# # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), +# wds.batched(args.batch_size, partial=not is_train) +# ]) + +# dataset = wds.DataPipeline(*pipeline) + +# if is_train: +# if not resampled: +# assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' +# # roll over and repeat a few samples to get same number of full batches on each node +# round_fn = math.floor if floor else math.ceil +# global_batch_size = args.batch_size * args.world_size +# num_batches = round_fn(num_samples / global_batch_size) +# num_workers = max(1, args.workers) +# num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker +# num_batches = num_worker_batches * num_workers +# num_samples = num_batches * global_batch_size +# dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this +# else: +# # last batches are partial, eval is done on single (master) node +# num_batches = math.ceil(num_samples / args.batch_size) + +# dataloader = wds.WebLoader( +# dataset, +# batch_size=None, +# shuffle=False, +# num_workers=args.workers, +# persistent_workers=True, +# ) + +# # FIXME not clear which approach is better, with_epoch before vs after dataloader? +# # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 +# # if is_train: +# # # roll over and repeat a few samples to get same number of full batches on each node +# # global_batch_size = args.batch_size * args.world_size +# # num_batches = math.ceil(num_samples / global_batch_size) +# # num_workers = max(1, args.workers) +# # num_batches = math.ceil(num_batches / num_workers) * num_workers +# # num_samples = num_batches * global_batch_size +# # dataloader = dataloader.with_epoch(num_batches) +# # else: +# # # last batches are partial, eval is done on single (master) node +# # num_batches = math.ceil(num_samples / args.batch_size) + +# # add meta-data to dataloader instance for convenience +# dataloader.num_batches = num_batches +# dataloader.num_samples = num_samples + +# return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + +def get_wds_dataset_cond(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False, filter_res=512, filter_mface=False, filter_wpose=False): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + 'Currently, number of dataset samples must be specified for training dataset. ' + 'Please specify via `--train-num-samples` if no dataset length info present.') + else: + num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] + else: + assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + + def pose2img(sample, scale): + height, width = sample["image"].height, sample["image"].width + # min_length = min(height, width) + # radius_body = int(4. * min_length / main_args.resolution) + # thickness_body = int(4. * min_length / main_args.resolution) + # radius_face = int(1.5 * min_length / main_args.resolution) + # thickness_face = int(2. * min_length / main_args.resolution) + # radius_hand = int(1.5 * min_length / main_args.resolution) + # thickness_hand = int(2. * min_length / main_args.resolution) + # if "getty" in sample["__url__"]: + # radius_body *= 4 + # thickness_body *= 4 + # radius_face *= 4 + # thickness_face *= 4 + # radius_hand *= 4 + # thickness_hand *= 4 + + try: + location = np.frombuffer(sample["location"], dtype=np.float32) + body_kp = np.frombuffer(sample["new_i_body_kp"], dtype=np.float32).reshape(-1, 17, 2) + x_coord = (body_kp[:, :, 0] - location[0]) / location[2] * location[7] + y_coord = (body_kp[:, :, 1] - location[1]) / location[3] * location[8] + body_kp = np.stack([x_coord, y_coord], axis=2) + body_kp = body_kp * scale + # body_kp[:, :, 0] -= j + # body_kp[:, :, 1] -= i + body_kpconf = np.frombuffer(sample["new_i_body_kp_score"], dtype=np.float32).reshape(-1, 17) + body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) + except: + body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) + body_kp = body_kp * scale + # body_kp[:, :, 0] -= j + # body_kp[:, :, 1] -= i + body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) + body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) + + # body_ratio = 0. + # for i_body in range(body_kp.shape[0]): + # body_ratio = max((np.max(body_kp[i_body, :, 0]) - np.min(body_kp[i_body, :, 0])) / min_length, body_ratio) + # print(body_ratio) + # body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) + # body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) + # body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) + # body_draw = draw_controlnet_skeleton(image=None, pose=body_all, height=height, width=width) + # body_draw = draw_humansd_skeleton(image=None, pose=body_all, height=height, width=width, humansd_skeleton_width=int(10. * body_ratio * min_length / main_args.resolution)) + body_draw = draw_humansd_skeleton( + # image=np.array(sample["image"]), + image=None, + pose=body_all, + height=height, + width=width, + humansd_skeleton_width=int(10 * main_args.resolution / 512), + ) + # body_draw = draw_body_skeleton( + # img=None, + # pose=body_all, + # radius=radius_body, + # thickness=thickness_body, + # height=height, + # width=width + # ) + body_draw = Image.fromarray(body_draw) + + try: + location = np.frombuffer(sample["location"], dtype=np.float32) + face_kp = np.frombuffer(sample["new_i_face_kp"], dtype=np.float32).reshape(-1, 68, 2) + x_coord = (face_kp[:, :, 0] - location[0]) / location[2] * location[7] + y_coord = (face_kp[:, :, 1] - location[1]) / location[3] * location[8] + face_kp = np.stack([x_coord, y_coord], axis=2) + face_kp = face_kp * scale + # face_kp[:, :, 0] -= j + # face_kp[:, :, 1] -= i + face_kpconf = np.frombuffer(sample["new_i_face_kp_score"], dtype=np.float32).reshape(-1, 68) + face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) + except: + face_kp = np.frombuffer(sample["new_face_kp"], dtype=np.float32).reshape(-1, 68, 2) + face_kp = face_kp * scale + # face_kp[:, :, 0] -= j + # face_kp[:, :, 1] -= i + face_kpconf = np.frombuffer(sample["new_face_kp_score"], dtype=np.float32).reshape(-1, 68) + face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) + + face_draw = draw_face_skeleton( + # img=np.array(sample["image"]), + img=None, + pose=face_all, + # radius=radius_face, + # thickness=thickness_face, + height=height, + width=width, + ) + face_draw = Image.fromarray(face_draw) + + try: + location = np.frombuffer(sample["location"], dtype=np.float32) + hand_kp = np.frombuffer(sample["new_i_hand_kp"], dtype=np.float32).reshape(-1, 21, 2) + x_coord = (hand_kp[:, :, 0] - location[0]) / location[2] * location[7] + y_coord = (hand_kp[:, :, 1] - location[1]) / location[3] * location[8] + hand_kp = np.stack([x_coord, y_coord], axis=2) + hand_kp = hand_kp * scale + # hand_kp[:, :, 0] -= j + # hand_kp[:, :, 1] -= i + hand_kpconf = np.frombuffer(sample["new_i_hand_kp_score"], dtype=np.float32).reshape(-1, 21) + hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) + except: + hand_kp = np.frombuffer(sample["new_hand_kp"], dtype=np.float32).reshape(-1, 21, 2) + hand_kp = hand_kp * scale + # hand_kp[:, :, 0] -= j + # hand_kp[:, :, 1] -= i + hand_kpconf = np.frombuffer(sample["new_hand_kp_score"], dtype=np.float32).reshape(-1, 21) + hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) + + hand_draw = draw_hand_skeleton( + # img=np.array(sample["image"]), + img=None, + pose=hand_all, + # radius=radius_hand, + # thickness=thickness_hand, + height=height, + width=width, + ) + hand_draw = Image.fromarray(hand_draw) + + # whole_kp = np.frombuffer(sample["new_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) + # whole_kpconf = np.frombuffer(sample["new_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) + # whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) + + try: + location = np.frombuffer(sample["location"], dtype=np.float32) + whole_kp = np.frombuffer(sample["new_i_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) + x_coord = (whole_kp[:, :, 0] - location[0]) / location[2] * location[7] + y_coord = (whole_kp[:, :, 1] - location[1]) / location[3] * location[8] + whole_kp = np.stack([x_coord, y_coord], axis=2) + whole_kp = whole_kp * scale + # whole_kp[:, :, 0] -= j + # whole_kp[:, :, 1] -= i + whole_kpconf = np.frombuffer(sample["new_i_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) + whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) + except: + whole_kp = np.frombuffer(sample["new_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) + whole_kp = whole_kp * scale + # whole_kp[:, :, 0] -= j + # whole_kp[:, :, 1] -= i + whole_kpconf = np.frombuffer(sample["new_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) + whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) + + whole_draw = draw_whole_body_skeleton( + # img=np.array(sample["image"]), + img=None, + pose=whole_all, + # radius=radius_body, + # thickness=thickness_body, + height=height, + width=width, + ) + whole_draw = Image.fromarray(whole_draw) + + sample["body"] = body_draw + sample["face"] = face_draw + sample["hand"] = hand_draw + + if main_args.change_whole_to_body: + sample["whole"] = body_draw + else: + sample["whole"] = whole_draw + + return sample + + def decode_image(sample): + with io.BytesIO(sample["omni_normal"]) as stream: + try: + img = PIL.Image.open(stream) + img.load() + img = img.convert("RGB") + img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) + sample["normal"] = img + except: + print("A broken image is encountered, replace w/ a placeholder") + image = Image.new('RGB', (main_args.resolution, main_args.resolution)) + sample["normal"] = image + + with io.BytesIO(sample["omni_depth"]) as stream: + try: + img = PIL.Image.open(stream) + img.load() + img = img.convert("RGB") + img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) + sample["depth"] = img + except: + print("A broken image is encountered, replace w/ a placeholder") + image = Image.new('RGB', (main_args.resolution, main_args.resolution)) + sample["depth"] = image + + with io.BytesIO(sample["midas_depth"]) as stream: + try: + img = PIL.Image.open(stream) + img.load() + img = img.convert("RGB") + img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) + sample["midas_depth"] = img + except: + print("A broken image is encountered, replace w/ a placeholder") + image = Image.new('RGB', (main_args.resolution, main_args.resolution)) + sample["midas_depth"] = image + + return sample + + def add_canny(sample): + canny = np.array(sample["image"]) + + low_threshold = 100 + high_threshold = 200 + + canny = cv2.Canny(canny, low_threshold, high_threshold) + canny = canny[:, :, None] + canny = np.concatenate([canny, canny, canny], axis=2) + sample["canny"] = Image.fromarray(canny) + return sample + + def decode_text(sample): + try: + sample["blip"] = sample["blip"].decode("utf-8") + sample["blip_raw"] = sample["blip"] + except: + sample["blip"] = sample["text"] + sample["blip_raw"] = sample["text"].encode("utf-8") + sample["text_raw"] = sample["text"] + return sample + + def augment_text(sample): + if is_train and string_concat: + sample["text"] = sample["text"] + " " + sample["blip"] + if is_train and string_substitute: + if main_args.rv_prompt: + sample["text"] = "RAW photo, " + sample["blip"] + ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" + else: + sample["text"] = sample["blip"] + return sample + + def dropout_text(sample): + if is_train: + try: + random_num = torch.rand(1) + if random_num < main_args.dropout_text: + sample["text"] = sample["text_raw"] = "" + except: + pass + + return sample + + def preprocess_image(sample): + # print(main_args.resolution, main_args.center_crop, main_args.random_flip) + if grid_dnc: + resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) + else: + resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) + + scale = main_args.resolution * 1. / min(sample["image"].height, sample["image"].width) + + sample["image"] = resize_transform(sample["image"]) + sample["normal"] = resize_transform(sample["normal"]) + sample["depth"] = resize_transform(sample["depth"]) + sample["midas_depth"] = resize_transform(sample["midas_depth"]) + sample["canny"] = resize_transform(sample["canny"]) + # sample["body"] = resize_transform(sample["body"]) + # sample["face"] = resize_transform(sample["face"]) + # sample["hand"] = resize_transform(sample["hand"]) + # sample["whole"] = resize_transform(sample["whole"]) + transform_list = [] + image_height, image_width = sample["image"].height, sample["image"].width + if grid_dnc: + i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() + j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() + else: + i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() + j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() + + if main_args.center_crop or not is_train: + sample["description"]["crop_tl_h"] = i = (image_height - main_args.resolution) // 2 + sample["description"]["crop_tl_w"] = j = (image_width - main_args.resolution) // 2 + if grid_dnc: + transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) + else: + transform_list.append(transforms.CenterCrop(main_args.resolution)) + else: + if grid_dnc: + if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: + raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") + + elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: + i, j = 0, 0 + + transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) + else: + if image_height < main_args.resolution or image_width < main_args.resolution: + raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") + + elif image_width == main_args.resolution and image_height == main_args.resolution: + i, j = 0, 0 + + transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) + sample["description"]["crop_tl_h"] = i + sample["description"]["crop_tl_w"] = j + + sample = pose2img(sample, scale) + + if is_train and torch.rand(1) < 0.5: + transform_list.append(transforms.RandomHorizontalFlip(p=1.)) + transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + train_transforms = transforms.Compose(transform_list) + sample["image"] = train_transforms(sample["image"]) + sample["normal"] = train_transforms(sample["normal"]) + sample["depth"] = train_transforms(sample["depth"]) + sample["midas_depth"] = train_transforms(sample["midas_depth"]) + sample["canny"] = train_transforms(sample["canny"]) + + sample["body"] = train_transforms(sample["body"]) + sample["face"] = train_transforms(sample["face"]) + sample["hand"] = train_transforms(sample["hand"]) + sample["whole"] = train_transforms(sample["whole"]) + return sample + + def random_mask(sample): + sample["normal_ori"] = sample["normal"].clone() + sample["depth_ori"] = sample["depth"].clone() + sample["midas_depth_ori"] = sample["midas_depth"].clone() + sample["canny_ori"] = sample["canny"].clone() + sample["body_ori"] = sample["body"].clone() + sample["face_ori"] = sample["face"].clone() + sample["hand_ori"] = sample["hand"].clone() + sample["whole_ori"] = sample["whole"].clone() + + mask_list = [] + + if is_train and dropout: + random_num = torch.rand(1) + + if random_num < 0.15: + sample["normal"] = torch.ones_like(sample["normal"]) * (-1) + sample["depth"] = torch.ones_like(sample["depth"]) * (-1) + sample["midas_depth"] = torch.ones_like(sample["midas_depth"]) * (-1) + sample["canny"] = torch.ones_like(sample["canny"]) * (-1) + sample["body"] = torch.ones_like(sample["body"]) * (-1) + sample["face"] = torch.ones_like(sample["face"]) * (-1) + sample["hand"] = torch.ones_like(sample["hand"]) * (-1) + sample["whole"] = torch.ones_like(sample["whole"]) * (-1) + mask_list = ["normal", "depth", "midas_depth", "canny", "body", "face", "hand", "whole"] + elif random_num > 0.9: + pass + else: + if torch.rand(1) < 0.5: + sample["normal"] = torch.ones_like(sample["normal"]) * (-1) + mask_list.append("normal") + if torch.rand(1) < 0.5: + sample["depth"] = torch.ones_like(sample["depth"]) * (-1) + mask_list.append("depth") + if torch.rand(1) < 0.5: + sample["midas_depth"] = torch.ones_like(sample["midas_depth"]) * (-1) + mask_list.append("midas_depth") + if torch.rand(1) < 0.8: + sample["canny"] = torch.ones_like(sample["canny"]) * (-1) + mask_list.append("canny") + if torch.rand(1) < 0.5: + sample["body"] = torch.ones_like(sample["body"]) * (-1) + mask_list.append("body") + if torch.rand(1) < 0.5: + sample["face"] = torch.ones_like(sample["face"]) * (-1) + mask_list.append("face") + if torch.rand(1) < 0.2: + sample["hand"] = torch.ones_like(sample["hand"]) * (-1) + mask_list.append("hand") + if torch.rand(1) < 0.5: + sample["whole"] = torch.ones_like(sample["whole"]) * (-1) + mask_list.append("whole") + + sample["normal_dt"] = sample["normal"].clone() + sample["depth_dt"] = sample["depth"].clone() + sample["midas_depth_dt"] = sample["midas_depth"].clone() + sample["canny_dt"] = sample["canny"].clone() + sample["body_dt"] = sample["body"].clone() + sample["face_dt"] = sample["face"].clone() + sample["hand_dt"] = sample["hand"].clone() + sample["whole_dt"] = sample["whole"].clone() + + mask_list = [x for x in mask_list if x in main_args.cond_type] + + if len(mask_list) > 0: + target = random.choice(mask_list) + sample[target + "_dt"] = sample[target + "_ori"].clone() + else: + if len(main_args.cond_type) > 0: + target = random.choice(main_args.cond_type) + sample[target + "_dt"] = torch.ones_like(sample[target]) * (-1) + + return sample + + def make_grid_dnc(sample): + + if grid_dnc: + resized_image = transforms.functional.resize(sample["image"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) + resized_depth = transforms.functional.resize(sample["depth"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) + resized_normal = transforms.functional.resize(sample["normal"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) + resized_canny = transforms.functional.resize(sample["body"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) + # resized_canny = transforms.functional.resize(sample["canny"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) + grid = torch.cat([torch.cat([resized_image, resized_depth], dim=2), + torch.cat([resized_normal, resized_canny], dim=2)], dim=1) + assert grid.shape[1] == main_args.resolution and grid.shape[2] == main_args.resolution + sample["image"] = grid + return sample + + def filter_low_res(sample): + if main_args.filter_res is None: + main_args.filter_res = main_args.resolution + if filter_lowres: + # string_json = sample["json"].decode('utf-8') + # dict_json = json.loads(string_json) + dict_json = sample["json"] + if "height" in dict_json.keys() and "width" in dict_json.keys(): + min_length = min(dict_json["height"], dict_json["width"]) + return min_length >= main_args.filter_res + else: + min_length = min(sample["image"].height, sample["image"].width) + return min_length >= main_args.filter_res + return True + + def filter_watermark(sample): + if main_args.filter_wm: + if sample["description"]["watermark"] >= 100: + return False + return True + + def add_original_hw(sample): + image_height, image_width = sample["image"].height, sample["image"].width + sample["description"] = {"h": image_height, "w": image_width} + return sample + + def add_description(sample): + # string_json = sample["json"].decode('utf-8') + # dict_json = json.loads(string_json) + try: + dict_json = sample["json"] + if "height" in dict_json.keys() and "width" in dict_json.keys(): + sample["description"]["h"] = dict_json["height"] + sample["description"]["w"] = dict_json["width"] + + # try: + if "coyo" in sample["__url__"]: + sample["description"]["aes"] = torch.tensor(sample["json"]["aesthetic_score_laion_v2"] * 1e2) + sample["description"]["watermark"] = torch.tensor(sample["json"]["watermark_score"] * 1e3) + elif "laion" in sample["__url__"]: + sample["description"]["aes"] = torch.tensor(np.frombuffer(sample["aesthetic_score_laion_v2"], dtype=np.float32) * 1e2) + sample["description"]["watermark"] = torch.tensor(np.frombuffer(sample["watermark_score"], dtype=np.float32) * 1e3) + elif "getty" in sample["__url__"]: + sample["description"]["aes"] = torch.tensor(np.frombuffer(sample["aesthetic_score_laion_v2"], dtype=np.float32) * 1e2) + sample["description"]["watermark"] = torch.tensor(float(sample["json"]["display_sizes"][-1]["is_watermarked"] or 0) * 1e3) + elif "fake" in sample["__url__"]: + sample["description"]["aes"] = torch.tensor(random.uniform(5.5, 6.0) * 1e2) + sample["description"]["watermark"] = torch.tensor(random.uniform(0., 0.1) * 1e3) + except: + # sample["description"]["h"] = + # sample["description"]["w"] = + sample["description"]["aes"] = torch.tensor(random.uniform(5.5, 6.0) * 1e2) + sample["description"]["watermark"] = torch.tensor(random.uniform(0., 0.1) * 1e3) + # except: + # sample["description"]["aes"] = 0. + # sample["description"]["watermark"] = 0. + + return sample + + def filter_multi_face(sample): + if filter_mface: + face_kp = np.frombuffer(sample["new_face_kp"], dtype=np.float32).reshape(-1, 68, 2) + if face_kp.shape[0] > 1: + return False + + return True + + def filter_whole_skeleton(sample): + if filter_wpose: + height, width = sample["image"].height, sample["image"].width + area = height * width + body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) + body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) + if (body_kp.shape[0] == 1) and (body_kpconf > 0.5).all() and (body_kp[0, :15, 0] > 0).all() \ + and (body_kp[0, :15, 0] < width).all() and (body_kp[0, :15, 1] > 0).all() and \ + (body_kp[0, :15, 1] < height).all(): + x_min = max(np.amin(body_kp[0, :, 0]), 0) + x_max = min(np.amax(body_kp[0, :, 0]), width) + y_min = max(np.amin(body_kp[0, :, 1]), 0) + y_max = min(np.amax(body_kp[0, :, 1]), height) + if (x_max - x_min) * (y_max - y_min) / area > 0.2: + return True + else: + return False + else: + return False + + return True + + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + wds.select(filter_multi_face), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="txt"), + wds.select(filter_whole_skeleton), + wds.select(filter_low_res), + wds.map(add_original_hw), + wds.map(decode_text), + wds.map(augment_text), + wds.map(dropout_text), + # wds.map(pose2img), + wds.map(decode_image), + wds.map(add_canny), + wds.map(preprocess_image), + wds.map(make_grid_dnc), + wds.map(random_mask), + wds.map(add_description), + wds.select(filter_watermark), + # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + wds.map_dict( + text=lambda text: tokenizer(text, \ + max_length=tokenizer.model_max_length, \ + padding="max_length", truncation=True, \ + return_tensors='pt')['input_ids'], + blip=lambda blip: tokenizer(blip, \ + max_length=tokenizer.model_max_length, \ + padding="max_length", truncation=True, \ + return_tensors='pt')['input_ids'] + ), + wds.to_tuple("image", "text", "text_raw", "blip", "blip_raw", \ + "body", "face", "hand", "normal", "depth", "midas_depth", "canny", "whole", "description", \ + "body_ori", "face_ori", "hand_ori", "normal_ori", "depth_ori", "midas_depth_ori", "canny_ori", "whole_ori", \ + "body_dt", "face_dt", "hand_dt", "normal_dt", "depth_dt", "midas_depth_dt", "canny_dt", "whole_dt"), + # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + +def get_wds_dataset_img(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + 'Currently, number of dataset samples must be specified for training dataset. ' + 'Please specify via `--train-num-samples` if no dataset length info present.') + else: + num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] + else: + assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp"), + # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + wds.map_dict(image=preprocess_img), + wds.to_tuple("image"), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + +def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + 'Currently, number of dataset samples must be specified for training dataset. ' + 'Please specify via `--train-num-samples` if no dataset length info present.') + else: + num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] + else: + assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="txt"), + # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors='pt')['input_ids']), + wds.to_tuple("image", "text"), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + + +def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = CsvDataset( + input_filename, + preprocess_fn, + img_key=args.csv_img_key, + caption_key=args.csv_caption_key, + sep=args.csv_separator, + tokenizer=tokenizer + ) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +class SyntheticDataset(Dataset): + + def __init__(self, transform=None, image_size=(224, 224), caption="Dummy caption", dataset_size=100, tokenizer=None): + self.transform = transform + self.image_size = image_size + self.caption = caption + self.image = Image.new('RGB', image_size) + self.dataset_size = dataset_size + + self.preprocess_txt = lambda text: tokenizer(text)[0] + + def __len__(self): + return self.dataset_size + + def __getitem__(self, idx): + if self.transform is not None: + image = self.transform(self.image) + return image, self.preprocess_txt(self.caption) + + +def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): + image_size = preprocess_fn.transforms[0].size + dataset = SyntheticDataset( + transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(data_path, dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "csv": + return get_csv_dataset + elif dataset_type == "synthetic": + return get_synthetic_dataset + elif dataset_type == "auto": + ext = data_path.split('.')[-1] + if ext in ['csv', 'tsv']: + return get_csv_dataset + elif ext in ['tar']: + return get_wds_dataset + else: + raise ValueError( + f"Tried to figure out dataset type, but failed for extension {ext}.") + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, preprocess_fns, epoch=0, tokenizer=None): + preprocess_train, preprocess_val = preprocess_fns + data = {} + + if args.train_data or args.dataset_type == "synthetic": + data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( + args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) + + if args.val_data: + data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( + args, preprocess_val, is_train=False, tokenizer=tokenizer) + + if args.imagenet_val is not None: + data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val") + + if args.imagenet_v2 is not None: + data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2") + + return data diff --git a/openclip/training/distributed.py b/openclip/training/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..268a6c7ad75a9ef29c72801dbf59d606f3318a59 --- /dev/null +++ b/openclip/training/distributed.py @@ -0,0 +1,137 @@ +import os + +import torch +import torch.distributed as dist + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def is_global_master(args): + return args.rank == 0 + + +def is_local_master(args): + return args.local_rank == 0 + + +def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + +def is_using_horovod(): + # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set + # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... + ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] + pmi_vars = ["PMI_RANK", "PMI_SIZE"] + if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): + return True + else: + return False + + +def is_using_distributed(): + if 'WORLD_SIZE' in os.environ: + return int(os.environ['WORLD_SIZE']) > 1 + if 'SLURM_NTASKS' in os.environ: + return int(os.environ['SLURM_NTASKS']) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): + if v in os.environ: + local_rank = int(os.environ[v]) + break + global_rank = 0 + for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): + if v in os.environ: + global_rank = int(os.environ[v]) + break + world_size = 1 + for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + if args.horovod: + assert hvd is not None, "Horovod is not installed" + hvd.init() + args.local_rank = int(hvd.local_rank()) + args.rank = hvd.rank() + args.world_size = hvd.size() + args.distributed = True + os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + elif is_using_distributed(): + if 'SLURM_PROCID' in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + + if torch.cuda.is_available(): + if args.distributed and not args.no_set_device_rank: + device = 'cuda:%d' % args.local_rank + else: + device = 'cuda:0' + torch.cuda.set_device(device) + else: + device = 'cpu' + args.device = device + device = torch.device(device) + return device + + +def broadcast_object(args, obj, src=0): + # broadcast a pickle-able python object from rank-0 to all ranks + if args.horovod: + return hvd.broadcast_object(obj, root_rank=src) + else: + if args.rank == src: + objects = [obj] + else: + objects = [None] + dist.broadcast_object_list(objects, src=src) + return objects[0] + + +def all_gather_object(args, obj, dst=0): + # gather a pickle-able python object across all ranks + if args.horovod: + return hvd.allgather_object(obj) + else: + objects = [None for _ in range(args.world_size)] + dist.all_gather_object(objects, obj) + return objects diff --git a/openclip/training/file_utils.py b/openclip/training/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..395cf7df0acc164c6851f17834d793f5852d4605 --- /dev/null +++ b/openclip/training/file_utils.py @@ -0,0 +1,83 @@ +import logging +import os +import multiprocessing +import subprocess +import time +import fsspec +import torch +from tqdm import tqdm + +def remote_sync_s3(local_dir, remote_dir): + # skip epoch_latest which can change during sync. + result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.returncode != 0: + logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") + return False + + logging.info(f"Successfully synced with S3 bucket") + return True + +def remote_sync_fsspec(local_dir, remote_dir): + # FIXME currently this is slow and not recommended. Look into speeding up. + a = fsspec.get_mapper(local_dir) + b = fsspec.get_mapper(remote_dir) + + for k in a: + # skip epoch_latest which can change during sync. + if 'epoch_latest.pt' in k: + continue + + logging.info(f'Attempting to sync {k}') + if k in b and len(a[k]) == len(b[k]): + logging.debug(f'Skipping remote sync for {k}.') + continue + + try: + logging.info(f'Successful sync for {k}.') + b[k] = a[k] + except Exception as e: + logging.info(f'Error during remote sync for {k}: {e}') + return False + + return True + +def remote_sync(local_dir, remote_dir, protocol): + logging.info('Starting remote sync.') + if protocol == 's3': + return remote_sync_s3(local_dir, remote_dir) + elif protocol == 'fsspec': + return remote_sync_fsspec(local_dir, remote_dir) + else: + logging.error('Remote protocol not known') + return False + +def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): + while True: + time.sleep(sync_every) + remote_sync(local_dir, remote_dir, protocol) + +def start_sync_process(sync_every, local_dir, remote_dir, protocol): + p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) + return p + +# Note: we are not currently using this save function. +def pt_save(pt_obj, file_path): + of = fsspec.open(file_path, "wb") + with of as f: + torch.save(pt_obj, file_path) + +def pt_load(file_path, map_location=None): + if file_path.startswith('s3'): + logging.info('Loading remote checkpoint, which may take a bit.') + of = fsspec.open(file_path, "rb") + with of as f: + out = torch.load(f, map_location=map_location) + return out + +def check_exists(file_path): + try: + with fsspec.open(file_path): + pass + except FileNotFoundError: + return False + return True diff --git a/openclip/training/imagenet_zeroshot_data.py b/openclip/training/imagenet_zeroshot_data.py new file mode 100644 index 0000000000000000000000000000000000000000..27abd8bf24ebe077a73e8496576d949d8bb16f69 --- /dev/null +++ b/openclip/training/imagenet_zeroshot_data.py @@ -0,0 +1,254 @@ + + +imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] + + + + + +openai_imagenet_template = [ + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] diff --git a/openclip/training/logger.py b/openclip/training/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9abed92568d459cbc8d6094ae3901935d89621 --- /dev/null +++ b/openclip/training/logger.py @@ -0,0 +1,26 @@ +import logging + + +def setup_logging(log_file, level, include_host=False): + if include_host: + import socket + hostname = socket.gethostname() + formatter = logging.Formatter( + f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') + else: + formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') + + logging.root.setLevel(level) + loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] + for logger in loggers: + logger.setLevel(level) + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logging.root.addHandler(stream_handler) + + if log_file: + file_handler = logging.FileHandler(filename=log_file) + file_handler.setFormatter(formatter) + logging.root.addHandler(file_handler) + diff --git a/openclip/training/main.py b/openclip/training/main.py new file mode 100644 index 0000000000000000000000000000000000000000..f70c9f9530559c08ef3a4baa525edb44f1b0d619 --- /dev/null +++ b/openclip/training/main.py @@ -0,0 +1,470 @@ +import glob +import logging +import os +import re +import subprocess +import sys +import random +from datetime import datetime + +import numpy as np +import torch +from torch import optim +from torch.cuda.amp import GradScaler + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss +from training.data import get_data +from training.distributed import is_master, init_distributed_device, broadcast_object +from training.logger import setup_logging +from training.params import parse_args +from training.scheduler import cosine_lr, const_lr, const_lr_cooldown +from training.train import train_one_epoch, evaluate +from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync + + +LATEST_CHECKPOINT_NAME = "epoch_latest.pt" + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def get_latest_checkpoint(path: str, remote : bool): + # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders + if remote: + result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + print(result) + if result.returncode == 1: + return None + checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]] + else: + checkpoints = glob.glob(path + '**/*.pt', recursive=True) + if checkpoints: + checkpoints = sorted(checkpoints, key=natural_key) + return checkpoints[-1] + return None + + +def main(args): + args = parse_args(args) + + if torch.cuda.is_available(): + # This enables tf32 on Ampere GPUs which is only 8% slower than + # float16 and almost as accurate as float32 + # This was a default in pytorch until 1.12 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + # fully initialize distributed device environment + device = init_distributed_device(args) + + # get the name of the experiments + if args.name is None: + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + model_name_safe = args.model.replace('/', '-') + date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + if args.distributed: + # sync date_str from master to all ranks + date_str = broadcast_object(args, date_str) + args.name = '-'.join([ + date_str, + f"model_{model_name_safe}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ]) + + resume_latest = args.resume == 'latest' + log_base_path = os.path.join(args.logs, args.name) + args.log_path = None + if is_master(args, local=args.log_local): + os.makedirs(log_base_path, exist_ok=True) + log_filename = f'out-{args.rank}' if args.log_local else 'out.log' + args.log_path = os.path.join(log_base_path, log_filename) + if os.path.exists(args.log_path) and not resume_latest: + print( + "Error. Experiment already exists. Use --name {} to specify a new experiment." + ) + return -1 + + # Setup text logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # Setup wandb, tensorboard, checkpoint logging + args.wandb = 'wandb' in args.report_to or 'all' in args.report_to + args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to + args.checkpoint_path = os.path.join(log_base_path, "checkpoints") + if is_master(args): + args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else '' + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = '' + + if resume_latest: + resume_from = None + checkpoint_path = args.checkpoint_path + # If using remote_sync, need to check the remote instead of the local checkpoints folder. + if args.remote_sync is not None: + checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints") + if args.save_most_recent: + print('Error. Cannot use save-most-recent with remote_sync and resume latest.') + return -1 + if args.remote_sync_protocol != 's3': + print('Error. Sync protocol not supported when using resume latest.') + return -1 + if is_master(args): + # Checking for existing checkpoint via master rank only. It is possible for + # different rank processes to see different files if a shared file-system is under + # stress, however it's very difficult to fully work around such situations. + if args.save_most_recent: + # if --save-most-recent flag is set, look for latest at a fixed filename + resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME) + if not os.path.exists(resume_from): + # If no latest checkpoint has been saved yet, don't try to resume + resume_from = None + else: + # otherwise, list checkpoint dir contents and pick the newest checkpoint + resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None) + if resume_from: + logging.info(f'Found latest resume checkpoint at {resume_from}.') + else: + logging.info(f'No latest resume checkpoint found in {checkpoint_path}.') + if args.distributed: + # sync found checkpoint path to all ranks + resume_from = broadcast_object(args, resume_from) + args.resume = resume_from + + if args.copy_codebase: + copy_codebase(args) + + # start the sync proces if remote-sync is not None + remote_sync_process = None + if is_master(args) and args.remote_sync is not None: + # first make sure it works + result = remote_sync( + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + if result: + logging.info('remote sync successful.') + else: + logging.info('Error: remote sync failed. Exiting.') + return -1 + # if all looks good, start a process to do this every args.remote_sync_frequency seconds + remote_sync_process = start_sync_process( + args.remote_sync_frequency, + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + remote_sync_process.start() + + if args.precision == 'fp16': + logging.warning( + 'It is recommended to use AMP mixed-precision instead of FP16. ' + 'FP16 support needs further verification and tuning, especially for train.') + + if args.horovod: + logging.info( + f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.' + f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') + elif args.distributed: + logging.info( + f'Running in distributed mode with multiple processes. Device: {args.device}.' + f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') + else: + logging.info(f'Running with a single process. Device {args.device}.') + + dist_model = None + args.distill = args.distill_model is not None and args.distill_pretrained is not None + if args.distill: + #FIXME: support distillation with grad accum. + assert args.accum_freq == 1 + #FIXME: support distillation with coca. + assert 'coca' not in args.model.lower() + + if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1: + # arg is nargs, single (square) image size list -> int + args.force_image_size = args.force_image_size[0] + random_seed(args.seed, 0) + model, preprocess_train, preprocess_val = create_model_and_transforms( + args.model, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + force_custom_text=args.force_custom_text, + force_patch_dropout=args.force_patch_dropout, + force_image_size=args.force_image_size, + pretrained_image=args.pretrained_image, + image_mean=args.image_mean, + image_std=args.image_std, + aug_cfg=args.aug_cfg, + output_dict=True, + ) + if args.distill: + # FIXME: currenlty assumes the model your distilling from has the same tokenizer & transforms. + dist_model, _, _ = create_model_and_transforms( + args.distill_model, + args.distill_pretrained, + device=device, + precision=args.precision, + output_dict=True, + ) + + random_seed(args.seed, args.rank) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if args.lock_image: + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + model.lock_image_tower( + unlocked_groups=args.lock_image_unlocked_groups, + freeze_bn_stats=args.lock_image_freeze_bn_stats) + if args.lock_text: + model.lock_text_tower( + unlocked_layers=args.lock_text_unlocked_layers, + freeze_layer_norm=args.lock_text_freeze_layer_norm) + + if args.grad_checkpointing: + model.set_grad_checkpointing() + + if is_master(args): + logging.info("Model:") + logging.info(f"{str(model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args['static_graph'] = True + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) + + if args.distill: + dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) + + # create optimizer and scaler + optimizer = None + scaler = None + + if args.train_data or args.dataset_type == "synthetic": + assert not args.trace, 'Cannot train with traced model' + + exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n + include = lambda n, p: not exclude(n, p) + + named_parameters = list(model.named_parameters()) + gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + optimizer = optim.AdamW( + [ + {"params": gain_or_bias_params, "weight_decay": 0.}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + ) + if args.horovod: + optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + checkpoint = pt_load(args.resume, map_location='cpu') + if 'epoch' in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith('module'): + sd = {k[len('module.'):]: v for k, v in sd.items()} + model.load_state_dict(sd) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and 'scaler' in checkpoint: + scaler.load_state_dict(checkpoint['scaler']) + logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") + + # initialize datasets + data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + assert len(data), 'At least one train or eval dataset must be specified.' + + # create scheduler if train + scheduler = None + if 'train' in data and optimizer is not None: + total_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs + if args.lr_scheduler == "cosine": + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + elif args.lr_scheduler == "const": + scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps) + elif args.lr_scheduler == "const-cooldown": + assert args.epochs_cooldown is not None,\ + "Please specify the number of cooldown epochs for this lr schedule." + cooldown_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs_cooldown + scheduler = const_lr_cooldown( + optimizer, args.lr, args.warmup, total_steps, + cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end) + else: + logging.error( + f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.') + exit(1) + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, 'Please install wandb.' + logging.debug('Starting wandb.') + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project=args.wandb_project_name, + name=args.name, + id=args.name, + notes=args.wandb_notes, + tags=[], + resume='auto' if args.resume == "latest" else None, + config=vars(args), + ) + if args.debug: + wandb.watch(model, log='all') + wandb.save(params_file) + logging.debug('Finished loading wandb.') + + if 'train' not in data: + evaluate(model, data, start_epoch, args, writer) + return + + loss = create_loss(args) + + for epoch in range(start_epoch, args.epochs): + if is_master(args): + logging.info(f'Start epoch {epoch}') + + train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer) + completed_epoch = epoch + 1 + + if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): + evaluate(model, data, completed_epoch, args, writer) + + # Saving checkpoints. + if args.save_logs: + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.delete_previous_checkpoint: + previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") + if os.path.exists(previous_checkpoint): + os.remove(previous_checkpoint) + + if args.save_most_recent: + # try not to corrupt the latest checkpoint if save fails + tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") + latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) + torch.save(checkpoint_dict, tmp_save_path) + os.replace(tmp_save_path, latest_save_path) + + if args.wandb and is_master(args): + wandb.finish() + + # run a final sync. + if remote_sync_process is not None: + logging.info('Final remote sync.') + remote_sync_process.terminate() + result = remote_sync( + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + if result: + logging.info('Final remote sync successful.') + else: + logging.info('Final remote sync failed.') + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/openclip/training/params.py b/openclip/training/params.py new file mode 100644 index 0000000000000000000000000000000000000000..36c693bc76e9e791d6a934ad564be84e1bbe1a4e --- /dev/null +++ b/openclip/training/params.py @@ -0,0 +1,435 @@ +import argparse +import ast + + +def get_default_params(model_name): + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + model_name = model_name.lower() + if "vit" in model_name: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} + else: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} + + +class ParseKwargs(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + kw = {} + for value in values: + key, value = value.split('=') + try: + kw[key] = ast.literal_eval(value) + except ValueError: + kw[key] = str(value) # fallback to string (avoid need to escape on command line) + setattr(namespace, self.dest, kw) + + +def parse_args(args): + parser = argparse.ArgumentParser() + parser.add_argument( + "--train-data", + type=str, + default=None, + help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.", + ) + parser.add_argument( + "--train-data-upsampling-factors", + type=str, + default=None, + help=( + "When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. " + "Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) " + "By default, datapoints are sampled uniformly regardless of the dataset sizes." + ) + ) + parser.add_argument( + "--val-data", + type=str, + default=None, + help="Path to file(s) with validation data", + ) + parser.add_argument( + "--train-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Required for webdataset if not available in info file.", + ) + parser.add_argument( + "--val-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Useful for webdataset if not available in info file.", + ) + parser.add_argument( + "--dataset-type", + choices=["webdataset", "csv", "synthetic", "auto"], + default="auto", + help="Which type of dataset to process." + ) + parser.add_argument( + "--dataset-resampled", + default=False, + action="store_true", + help="Whether to use sampling with replacement for webdataset shard selection." + ) + parser.add_argument( + "--csv-separator", + type=str, + default="\t", + help="For csv-like datasets, which separator to use." + ) + parser.add_argument( + "--csv-img-key", + type=str, + default="filepath", + help="For csv-like datasets, the name of the key for the image paths." + ) + parser.add_argument( + "--csv-caption-key", + type=str, + default="title", + help="For csv-like datasets, the name of the key for the captions." + ) + parser.add_argument( + "--imagenet-val", + type=str, + default=None, + help="Path to imagenet val set for conducting zero shot evaluation.", + ) + parser.add_argument( + "--imagenet-v2", + type=str, + default=None, + help="Path to imagenet v2 for conducting zero shot evaluation.", + ) + parser.add_argument( + "--logs", + type=str, + default="./logs/", + help="Where to store tensorboard logs. Use None to avoid storing logs.", + ) + parser.add_argument( + "--log-local", + action="store_true", + default=False, + help="log files on local master, otherwise global master only.", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Optional identifier for the experiment when storing logs. Otherwise use current time.", + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of dataloader workers per GPU." + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="Batch size per GPU." + ) + parser.add_argument( + "--epochs", type=int, default=32, help="Number of epochs to train for." + ) + parser.add_argument( + "--epochs-cooldown", type=int, default=None, + help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards." + ) + parser.add_argument("--lr", type=float, default=None, help="Learning rate.") + parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") + parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") + parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") + parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") + parser.add_argument( + "--warmup", type=int, default=10000, help="Number of steps to warmup for." + ) + parser.add_argument( + "--use-bn-sync", + default=False, + action="store_true", + help="Whether to use batch norm sync.") + parser.add_argument( + "--skip-scheduler", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--lr-scheduler", + type=str, + default='cosine', + help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine", + ) + parser.add_argument( + "--lr-cooldown-end", type=float, default=0.0, + help="End learning rate for cooldown schedule. Default: 0" + ) + parser.add_argument( + "--lr-cooldown-power", type=float, default=1.0, + help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)" + ) + parser.add_argument( + "--save-frequency", type=int, default=1, help="How often to save checkpoints." + ) + parser.add_argument( + "--save-most-recent", + action="store_true", + default=False, + help="Always save the most recent model trained to epoch_latest.pt.", + ) + parser.add_argument( + "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." + ) + parser.add_argument( + "--val-frequency", type=int, default=1, help="How often to run evaluation with val data." + ) + parser.add_argument( + "--resume", + default=None, + type=str, + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "--precision", + choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], + default="amp", + help="Floating point precision." + ) + parser.add_argument( + "--model", + type=str, + default="RN50", + help="Name of the vision backbone to use.", + ) + parser.add_argument( + "--pretrained", + default='', + type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--pretrained-image", + default=False, + action='store_true', + help="Load imagenet pretrained weights for image tower backbone if available.", + ) + parser.add_argument( + "--lock-image", + default=False, + action='store_true', + help="Lock full image tower by disabling gradients.", + ) + parser.add_argument( + "--lock-image-unlocked-groups", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-image-freeze-bn-stats", + default=False, + action='store_true', + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs) + parser.add_argument( + "--grad-checkpointing", + default=False, + action='store_true', + help="Enable gradient checkpointing.", + ) + parser.add_argument( + "--local-loss", + default=False, + action="store_true", + help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)" + ) + parser.add_argument( + "--gather-with-grad", + default=False, + action="store_true", + help="enable full distributed gradient for feature gather" + ) + parser.add_argument( + '--force-image-size', type=int, nargs='+', default=None, + help='Override default image size' + ) + parser.add_argument( + "--force-quick-gelu", + default=False, + action='store_true', + help="Force use of QuickGELU activation for non-OpenAI transformer models.", + ) + parser.add_argument( + "--force-patch-dropout", + default=None, + type=float, + help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper", + ) + parser.add_argument( + "--force-custom-text", + default=False, + action='store_true', + help="Force use of CustomTextCLIP model (separate text-tower).", + ) + parser.add_argument( + "--torchscript", + default=False, + action='store_true', + help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", + ) + parser.add_argument( + "--trace", + default=False, + action='store_true', + help="torch.jit.trace the model for inference / eval only", + ) + parser.add_argument( + "--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps." + ) + # arguments for distributed training + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--report-to", + default='', + type=str, + help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']" + ) + parser.add_argument( + "--wandb-notes", + default='', + type=str, + help="Notes if logging with wandb" + ) + parser.add_argument( + "--wandb-project-name", + type=str, + default='open-clip', + help="Name of the project if logging with wandb.", + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="If true, more information is logged." + ) + parser.add_argument( + "--copy-codebase", + default=False, + action="store_true", + help="If true, we copy the entire base on the log directory, and execute from there." + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training." + ) + parser.add_argument( + "--ddp-static-graph", + default=False, + action='store_true', + help="Enable static graph optimization for DDP in PyTorch >= 1.11.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)." + ) + parser.add_argument( + "--seed", type=int, default=0, help="Default random seed." + ) + parser.add_argument( + "--grad-clip-norm", type=float, default=None, help="Gradient clip." + ) + parser.add_argument( + "--lock-text", + default=False, + action='store_true', + help="Lock full text tower by disabling gradients.", + ) + parser.add_argument( + "--lock-text-unlocked-layers", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-text-freeze-layer-norm", + default=False, + action='store_true', + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + "--log-every-n-steps", + type=int, + default=100, + help="Log every n steps to tensorboard/console/wandb.", + ) + parser.add_argument( + "--coca-caption-loss-weight", + type=float, + default=2.0, + help="Weight assigned to caption loss in CoCa." + ) + parser.add_argument( + "--coca-contrastive-loss-weight", + type=float, + default=1.0, + help="Weight assigned to contrastive loss when training CoCa." + ) + parser.add_argument( + "--remote-sync", + type=str, + default=None, + help="Optinoally sync with a remote path specified by this arg", + ) + parser.add_argument( + "--remote-sync-frequency", + type=int, + default=300, + help="How frequently to sync to a remote directly if --remote-sync is not None.", + ) + parser.add_argument( + "--remote-sync-protocol", + choices=["s3", "fsspec"], + default="s3", + help="How to do the remote sync backup if --remote-sync is not None.", + ) + parser.add_argument( + "--delete-previous-checkpoint", + default=False, + action="store_true", + help="If true, delete previous checkpoint after storing a new one." + ) + parser.add_argument( + "--distill-model", + default=None, + help='Which model arch to distill from, if any.' + ) + parser.add_argument( + "--distill-pretrained", + default=None, + help='Which pre-trained weights to distill from, if any.' + ) + args = parser.parse_args(args) + + # If some params are not passed, we use the default values based on model name. + default_params = get_default_params(args.model) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) + + return args diff --git a/openclip/training/precision.py b/openclip/training/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..a63b92256518d13afd57261df1568e26b1622201 --- /dev/null +++ b/openclip/training/precision.py @@ -0,0 +1,12 @@ +import torch +from contextlib import suppress + + +def get_autocast(precision): + if precision == 'amp': + return torch.cuda.amp.autocast + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + # amp_bfloat16 is more stable than amp float16 for clip training + return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return suppress diff --git a/openclip/training/profile.py b/openclip/training/profile.py new file mode 100644 index 0000000000000000000000000000000000000000..f10372cdef306e5e199db432b23062df1c098cf9 --- /dev/null +++ b/openclip/training/profile.py @@ -0,0 +1,158 @@ +import argparse + +import torch +import open_clip +import pandas as pd +from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis + + +parser = argparse.ArgumentParser(description='OpenCLIP Profiler') + +# benchmark specific args +parser.add_argument('--model', metavar='NAME', default='', + help='model(s) to profile') +parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', + help='Output csv file for results') + + +def profile_fvcore( + model, + image_input_size=(3, 224, 224), + text_input_size=(77,), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) + example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) + fca = FlopCountAnalysis(model, (example_image_input, example_text_input)) + aca = ActivationCountAnalysis(model, (example_image_input, example_text_input)) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def profile_fvcore_text( + model, + text_input_size=(77,), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device = next(model.parameters()).device + example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) + fca = FlopCountAnalysis(model, example_input) + aca = ActivationCountAnalysis(model, example_input) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def profile_fvcore_image( + model, + image_input_size=(3, 224, 224), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) + fca = FlopCountAnalysis(model, example_input) + aca = ActivationCountAnalysis(model, example_input) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def count_params(model): + return sum([m.numel() for m in model.parameters()]) + + +def profile_model(model_name): + model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) + model.eval() + if torch.cuda.is_available(): + model = model.cuda() + + if isinstance(model.visual.image_size, (tuple, list)): + image_input_size = (3,) + tuple(model.visual.image_size[-2:]) + else: + image_input_size = (3, model.visual.image_size, model.visual.image_size) + text_input_size = (77,) + + results = {} + results['model'] = model_name + results['image_size'] = image_input_size[1] + + model_cfg = open_clip.get_model_config(model_name) + if model_cfg: + vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) + text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) + results['image_width'] = int(vision_cfg.width) + results['text_width'] = int(text_cfg.width) + results['embed_dim'] = int(model_cfg['embed_dim']) + else: + results['image_width'] = 0 + results['text_width'] = 0 + results['embed_dim'] = 0 + + retries = 2 + while retries: + retries -= 1 + try: + macs, acts = profile_fvcore( + model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries) + + image_macs, image_acts = profile_fvcore_image( + model.visual, image_input_size=image_input_size, force_cpu=not retries) + + text_macs, text_acts = profile_fvcore_text( + model.text, text_input_size=text_input_size, force_cpu=not retries) + + results['gmacs'] = round(macs / 1e9, 2) + results['macts'] = round(acts / 1e6, 2) + results['mparams'] = round(count_params(model) / 1e6, 2) + results['image_gmacs'] = round(image_macs / 1e9, 2) + results['image_macts'] = round(image_acts / 1e6, 2) + results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) + results['text_gmacs'] = round(text_macs / 1e9, 2) + results['text_macts'] = round(text_acts / 1e6, 2) + results['text_mparams'] = round(count_params(model.text) / 1e6, 2) + except RuntimeError as e: + pass + return results + + +def main(): + args = parser.parse_args() + + # FIXME accept a text file name to allow lists of models in txt/csv + if args.model == 'all': + parsed_model = open_clip.list_models() + else: + parsed_model = args.model.split(',') + + results = [] + for m in parsed_model: + row = profile_model(m) + results.append(row) + + df = pd.DataFrame(results, columns=results[0].keys()) + df = df.sort_values('gmacs') + print(df) + if args.results_file: + df.to_csv(args.results_file, index=False) + + +if __name__ == '__main__': + main() diff --git a/openclip/training/scheduler.py b/openclip/training/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..fba76fcf1720b11d136a5ab6d3a58ab2fbe42f74 --- /dev/null +++ b/openclip/training/scheduler.py @@ -0,0 +1,53 @@ +import numpy as np + + +def assign_learning_rate(optimizer, new_lr): + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + + +def _warmup_lr(base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + +def const_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + lr = base_lr + assign_learning_rate(optimizer, lr) + return lr + return _lr_adjuster + + +def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): + def _lr_adjuster(step): + start_cooldown_step = steps - cooldown_steps + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + if step < start_cooldown_step: + lr = base_lr + else: + e = step - start_cooldown_step + es = steps - start_cooldown_step + # linear decay if power == 1; polynomial decay otherwise; + decay = (1 - (e/es)) ** cooldown_power + lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr + assign_learning_rate(optimizer, lr) + return lr + return _lr_adjuster + + +def cosine_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(optimizer, lr) + return lr + return _lr_adjuster diff --git a/openclip/training/train.py b/openclip/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a140f9c83208c6fe95d07d37fe0a075e3d456b --- /dev/null +++ b/openclip/training/train.py @@ -0,0 +1,361 @@ +import json +import logging +import math +import os +import time + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.parallel.distributed import DistributedDataParallel + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import get_cast_dtype, CLIP, CustomTextCLIP +from .distributed import is_master +from .zero_shot import zero_shot_eval +from .precision import get_autocast + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def postprocess_clip_output(model_out): + return { + "image_features": model_out[0], + "text_features": model_out[1], + "logit_scale": model_out[2] + } + +def unwrap_model(model): + if hasattr(model, 'module'): + return model.module + else: + return model + + +def backward(total_loss, scaler): + if scaler is not None: + scaler.scale(total_loss).backward() + else: + total_loss.backward() + + +def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None): + device = torch.device(args.device) + autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) + + + model.train() + if args.distill: + dist_model.eval() + + data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch + dataloader = data['train'].dataloader + num_batches_per_epoch = dataloader.num_batches // args.accum_freq + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + if args.accum_freq > 1: + accum_images, accum_texts, accum_features = [], [], {} + + losses_m = {} + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + for i, batch in enumerate(dataloader): + i_accum = i // args.accum_freq + step = num_batches_per_epoch * epoch + i_accum + + if not args.skip_scheduler: + scheduler(step) + + images, texts = batch + images = images.to(device=device, dtype=cast_dtype, non_blocking=True) + texts = texts.to(device=device, non_blocking=True) + + data_time_m.update(time.time() - end) + optimizer.zero_grad() + + if args.accum_freq == 1: + with autocast(): + model_out = model(images, texts) + logit_scale = model_out["logit_scale"] + if args.distill: + with torch.no_grad(): + dist_model_out = dist_model(images, texts) + model_out.update({f'dist_{k}' : v for k, v in dist_model_out.items()}) + losses = loss(**model_out, output_dict=True) + + total_loss = sum(losses.values()) + losses["loss"] = total_loss + + backward(total_loss, scaler) + else: + # First, cache the features without any gradient tracking. + with torch.no_grad(): + with autocast(): + model_out = model(images, texts) + model_out.pop("logit_scale") + for key, val in model_out.items(): + if key in accum_features: + accum_features[key].append(val) + else: + accum_features[key] = [val] + + accum_images.append(images) + accum_texts.append(texts) + + # If (i + 1) % accum_freq is not zero, move on to the next batch. + if ((i + 1) % args.accum_freq) > 0: + # FIXME this makes data time logging unreliable when accumulating + continue + + # Now, ready to take gradients for the last accum_freq batches. + # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives. + # Call backwards each time, but only step optimizer at the end. + optimizer.zero_grad() + for j in range(args.accum_freq): + images = accum_images[j] + texts = accum_texts[j] + with autocast(): + model_out = model(images, texts) + logit_scale = model_out.pop("logit_scale") + inputs = {} + for key, val in accum_features.items(): + accumulated = accum_features[key] + inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:]) + losses = loss(**inputs, logit_scale=logit_scale, output_dict=True) + del inputs + total_loss = sum(losses.values()) + losses["loss"] = total_loss + backward(total_loss, scaler) + + if scaler is not None: + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + if args.grad_clip_norm is not None: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + scaler.step(optimizer) + scaler.update() + else: + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + optimizer.step() + + # reset gradient accum, if enabled + if args.accum_freq > 1: + accum_images, accum_texts, accum_features = [], [], {} + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i_accum + 1 + if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): + batch_size = len(images) + num_samples = batch_count * batch_size * args.accum_freq * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + for key, val in losses.items(): + if key not in losses_m: + losses_m[key] = AverageMeter() + losses_m[key].update(val.item(), batch_size) + + logit_scale_scalar = logit_scale.item() + loss_log = " ".join( + [ + f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" + for loss_name, loss_m in losses_m.items() + ] + ) + samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val + samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "samples_per_second": samples_per_second, + "samples_per_second_per_gpu": samples_per_second_per_gpu, + "scale": logit_scale_scalar, + "lr": optimizer.param_groups[0]["lr"] + } + log_data.update({name:val.val for name,val in losses_m.items()}) + + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, 'Please install wandb.' + wandb.log({name: val, 'step': step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None): + metrics = {} + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + metrics.update(zero_shot_metrics) + + autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) + + if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): + dataloader = data['val'].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + # FIXME this does not scale past small eval datasets + # all_image_features @ all_text_features will blow up memory and compute very quickly + cumulative_loss = 0.0 + cumulative_gen_loss = 0.0 + all_image_features, all_text_features = [], [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + images, texts = batch + images = images.to(device=device, dtype=cast_dtype, non_blocking=True) + texts = texts.to(device=device, non_blocking=True) + + with autocast(): + model_out = model(images, texts) + image_features = model_out["image_features"] + text_features = model_out["text_features"] + logit_scale = model_out["logit_scale"] + # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly + # however, system RAM is easily exceeded and compute time becomes problematic + all_image_features.append(image_features.cpu()) + all_text_features.append(text_features.cpu()) + logit_scale = logit_scale.mean() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + batch_size = images.shape[0] + labels = torch.arange(batch_size, device=device).long() + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + gen_loss = maybe_compute_generative_loss(model_out) + + cumulative_loss += total_loss * batch_size + num_samples += batch_size + if is_master(args) and (i % 100) == 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" + f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") + + if gen_loss is not None: + cumulative_gen_loss += gen_loss * batch_size + logging.info( + f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") + + val_metrics = get_clip_metrics( + image_features=torch.cat(all_image_features), + text_features=torch.cat(all_text_features), + logit_scale=logit_scale.cpu(), + ) + loss = cumulative_loss / num_samples + metrics.update( + {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} + ) + if gen_loss is not None: + gen_loss = cumulative_gen_loss / num_samples + metrics.update({"val_generative_loss": gen_loss.item()}) + + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, 'Please install wandb.' + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, 'epoch': epoch}) + + return metrics + + +def get_clip_metrics(image_features, text_features, logit_scale): + metrics = {} + logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() + logits_per_text = logits_per_image.t().detach().cpu() + + logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + for name, logit in logits.items(): + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[1] + preds = preds.detach().cpu().numpy() + metrics[f"{name}_mean_rank"] = preds.mean() + 1 + metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = np.mean(preds < k) + + return metrics + + +def maybe_compute_generative_loss(model_out): + if "logits" in model_out and "labels" in model_out: + token_logits = model_out["logits"] + token_labels = model_out["labels"] + return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) diff --git a/openclip/training/zero_shot.py b/openclip/training/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..e5768b4a3ce26f0a9a12d8ee3a6d9490e778a78a --- /dev/null +++ b/openclip/training/zero_shot.py @@ -0,0 +1,93 @@ +import logging + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import get_cast_dtype, get_tokenizer +from .precision import get_autocast +from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template + + +def zero_shot_classifier(model, classnames, templates, args): + tokenizer = get_tokenizer(args.model) + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(classnames): + texts = [template(classname) for template in templates] # format with class + texts = tokenizer(texts).to(args.device) # tokenize + if args.distributed and not args.horovod: + class_embeddings = model.module.encode_text(texts) + else: + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) + return zeroshot_weights + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def run(model, classifier, dataloader, args): + autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) + with torch.no_grad(): + top1, top5, n = 0., 0., 0. + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(args.device) + if cast_dtype is not None: + images = images.to(dtype=cast_dtype) + target = target.to(args.device) + + with autocast(): + # predict + if args.distributed and not args.horovod: + image_features = model.module.encode_image(images) + else: + image_features = model.encode_image(images) + image_features = F.normalize(image_features, dim=-1) + logits = 100. * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = (top1 / n) + top5 = (top5 / n) + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + if 'imagenet-val' not in data and 'imagenet-v2' not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + + logging.info('Starting zero-shot imagenet.') + + logging.info('Building zero-shot classifier') + classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args) + + logging.info('Using classifier') + results = {} + if 'imagenet-val' in data: + top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) + results['imagenet-zeroshot-val-top1'] = top1 + results['imagenet-zeroshot-val-top5'] = top5 + if 'imagenet-v2' in data: + top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) + results['imagenetv2-zeroshot-val-top1'] = top1 + results['imagenetv2-zeroshot-val-top5'] = top5 + + logging.info('Finished zero-shot imagenet.') + + return results diff --git a/output-images/body.png b/output-images/body.png new file mode 100644 index 0000000000000000000000000000000000000000..ecd859c38a2f17bf9caa92bdbc6d47e9fc9d1cef Binary files /dev/null and b/output-images/body.png differ diff --git a/output-images/depth.png b/output-images/depth.png new file mode 100644 index 0000000000000000000000000000000000000000..4310339d83c4fb6e8d89e21ea80db0104a3bd683 Binary files /dev/null and b/output-images/depth.png differ diff --git a/output-images/normal.png b/output-images/normal.png new file mode 100644 index 0000000000000000000000000000000000000000..79bfbae821f14e1ba7fc42f64feabb7015d2aaef Binary files /dev/null and b/output-images/normal.png differ diff --git a/output-images/rgb.png b/output-images/rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..18fa3025c5c35cd11b3dd593ae51fb590eefcfd3 Binary files /dev/null and b/output-images/rgb.png differ diff --git a/output-images/rgb2.png b/output-images/rgb2.png new file mode 100644 index 0000000000000000000000000000000000000000..0fd6cdb492279169cfedc9093cd2f1e25341713f --- /dev/null +++ b/output-images/rgb2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4d219735c5eb6daeb83ef3fceca890b03b53e4b0f70d4af46f6991744cf8647 +size 1595542 diff --git a/pipelines/__init__.py b/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipelines/__pycache__/__init__.cpython-310.pyc b/pipelines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b083dbe23f3a294ace33a736515c6aad7cd2934b Binary files /dev/null and b/pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/pipelines/__pycache__/__init__.cpython-38.pyc b/pipelines/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ac40192d36fd23d35fdb59c08d348d87ae30e1f Binary files /dev/null and b/pipelines/__pycache__/__init__.cpython-38.pyc differ diff --git a/pipelines/__pycache__/multicontrolnet.cpython-310.pyc b/pipelines/__pycache__/multicontrolnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e43d401e51b6e76b6961978f0f77bc22c5777008 Binary files /dev/null and b/pipelines/__pycache__/multicontrolnet.cpython-310.pyc differ diff --git a/pipelines/__pycache__/multicontrolnet.cpython-38.pyc b/pipelines/__pycache__/multicontrolnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..009cbe486e48b10781ef81df9dd14b3e293228ef Binary files /dev/null and b/pipelines/__pycache__/multicontrolnet.cpython-38.pyc differ diff --git a/pipelines/__pycache__/pipeline_controlnet_composer.cpython-310.pyc b/pipelines/__pycache__/pipeline_controlnet_composer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56f37992f67f6b15cc49a7283e026273001ef89b Binary files /dev/null and b/pipelines/__pycache__/pipeline_controlnet_composer.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_controlnet_composer_gating.cpython-310.pyc b/pipelines/__pycache__/pipeline_controlnet_composer_gating.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56097399944ce8dd9228229ddd267d7276592789 Binary files /dev/null and b/pipelines/__pycache__/pipeline_controlnet_composer_gating.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_controlnet_composer_sdxl.cpython-310.pyc b/pipelines/__pycache__/pipeline_controlnet_composer_sdxl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d28fb2ff9fc72b9f7317ac189b7f6f0e38f12db Binary files /dev/null and b/pipelines/__pycache__/pipeline_controlnet_composer_sdxl.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_controlnet_composer_sdxl.cpython-38.pyc b/pipelines/__pycache__/pipeline_controlnet_composer_sdxl.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b9dc4d88e2577ecfd8a821079564e55f965ba34 Binary files /dev/null and b/pipelines/__pycache__/pipeline_controlnet_composer_sdxl.cpython-38.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_mb_downup.cpython-310.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_mb_downup.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a2afe3f6fa7a110d1fa56fdc8ac0614b35a32bb Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_mb_downup.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_mb_downup.cpython-38.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_mb_downup.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b9d0aabd764f3bb58b9ec63ad16bb07e6c91ef1 Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_mb_downup.cpython-38.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_spade.cpython-310.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_spade.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c18d839ec8a1699ffdaa7ae45642e7fe2490afaa Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_spade.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_spade.cpython-38.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_spade.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5ea31bbed2db4bbed20c93ea4392ff6df89d38a Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_spade.cpython-38.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe.cpython-310.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..690d8227994c642d580166ae4743ef29998a1885 Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe3.cpython-310.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae66db5696a2a4ed50f7eba5d533dca30dd90f84 Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe3.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe4.cpython-310.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe4.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cad82cf7191fe5f175219623fa39a95f8cecb111 Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe4.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe5.cpython-310.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe5.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..763815d6867dd6cca5891856f8ea1e8b8831f52c Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_spade_timemoe5.cpython-310.pyc differ diff --git a/pipelines/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc b/pipelines/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d3983051cb4b8b5ddadc08223a3440e3847c15f Binary files /dev/null and b/pipelines/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc differ diff --git a/pipelines/multicontrolnet.py b/pipelines/multicontrolnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a58108f81828839f710ca2fd5eff3f07a4cea6b0 --- /dev/null +++ b/pipelines/multicontrolnet.py @@ -0,0 +1,186 @@ +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from diffusers.models.controlnet_composer import ControlNetModel, ControlNetOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) + + +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + # controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=image, + conditioning_scale=scale, + class_labels=class_labels, + timestep_cond=timestep_cond, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + guess_mode=guess_mode, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + """ + idx = 0 + model_path_to_save = save_directory + for controlnet in self.nets: + controlnet.save_pretrained( + model_path_to_save, + is_main_process=is_main_process, + save_function=save_function, + safe_serialization=safe_serialization, + variant=variant, + ) + + idx += 1 + model_path_to_save = model_path_to_save + f"_{idx}" + + @classmethod + def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_path (`os.PathLike`): + A path to a *directory* containing model weights saved using + [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g., + `./my_model_directory/controlnet`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from + `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. + """ + idx = 0 + controlnets = [] + + # load controlnet and append to list until no controlnet directory exists anymore + # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained` + # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ... + model_path_to_load = pretrained_model_path + while os.path.isdir(model_path_to_load): + controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs) + controlnets.append(controlnet) + + idx += 1 + model_path_to_load = pretrained_model_path + f"_{idx}" + + logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.") + + if len(controlnets) == 0: + raise ValueError( + f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." + ) + + return cls(controlnets) diff --git a/pipelines/pipeline_controlnet_composer.py b/pipelines/pipeline_controlnet_composer.py new file mode 100644 index 0000000000000000000000000000000000000000..8edc88515e3b80e93992e8e08b6e690602a16c3c --- /dev/null +++ b/pipelines/pipeline_controlnet_composer.py @@ -0,0 +1,1017 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel +from diffusers.models.controlnet_composer import ControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image_list, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + for image in image_list: + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image_list, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + output = [] + for image in image_list: + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + output.append(image) + + return output + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + List[List[torch.FloatTensor]], + List[List[PIL.Image.Image]], + List[List[np.ndarray]], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + # self.check_inputs( + # prompt, + # image, + # callback_steps, + # negative_prompt, + # prompt_embeds, + # negative_prompt_embeds, + # controlnet_conditioning_scale, + # control_guidance_start, + # control_guidance_end, + # ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image_list=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image[0].shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image_list=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0][0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/pipelines/pipeline_controlnet_composer_gating.py b/pipelines/pipeline_controlnet_composer_gating.py new file mode 100644 index 0000000000000000000000000000000000000000..98570f69ab04e9ae2fa2a675dd50e9ab04b2b9f6 --- /dev/null +++ b/pipelines/pipeline_controlnet_composer_gating.py @@ -0,0 +1,1050 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel +from diffusers.models.controlnet_composer import ControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + gating_unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + gating_unet=gating_unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image_list, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + for image in image_list: + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image_list, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + output = [] + for image in image_list: + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + output.append(image) + + return output + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + List[List[torch.FloatTensor]], + List[List[PIL.Image.Image]], + List[List[np.ndarray]], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + args=None, + batch=None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + # self.check_inputs( + # prompt, + # image, + # callback_steps, + # negative_prompt, + # prompt_embeds, + # negative_prompt_embeds, + # controlnet_conditioning_scale, + # control_guidance_start, + # control_guidance_end, + # ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image_list=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image[0].shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image_list=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0][0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + midas_depth_latents = self.vae.encode(batch["midas_depth"].to(dtype=latents.dtype)).latent_dist.sample() + midas_depth_latents = midas_depth_latents * self.vae.config.scaling_factor + if args.normalize_dist: + midas_depth_latents = (midas_depth_latents - args.depth_mean) / args.depth_std * args.rgb_std + args.rgb_mean + midas_depth_latents_input = torch.cat([midas_depth_latents] * 2) if do_classifier_free_guidance else midas_depth_latents + + normal_latents = self.vae.encode(batch["normal"].to(dtype=latents.dtype)).latent_dist.sample() + normal_latents = normal_latents * self.vae.config.scaling_factor + if args.normalize_dist: + normal_latents = (normal_latents - args.normal_mean) / args.normal_std * args.rgb_std + args.rgb_mean + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + + whole_latents = self.vae.encode(batch["whole"].to(dtype=latents.dtype)).latent_dist.sample() + whole_latents = whole_latents * self.vae.config.scaling_factor + if args.normalize_dist: + whole_latents = (whole_latents - args.whole_mean) / args.whole_std * args.rgb_std + args.rgb_mean + whole_latents_input = torch.cat([whole_latents] * 2) if do_classifier_free_guidance else whole_latents + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + input_to_gating = torch.cat([control_model_input, midas_depth_latents_input, normal_latents_input, whole_latents_input], dim=1) + # print(input_to_gating.shape, t.shape, controlnet_prompt_embeds.shape) + gating_matrix = self.gating_unet( + input_to_gating, + t, + encoder_hidden_states=controlnet_prompt_embeds, + ).sample + # shape is (B, 3, H, W) + gating_matrix = F.softmax(gating_matrix, dim=1) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + gating_matrix=gating_matrix, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/pipelines/pipeline_controlnet_composer_sdxl.py b/pipelines/pipeline_controlnet_composer_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..64882704f5c84053752d543916ac850b07c773f9 --- /dev/null +++ b/pipelines/pipeline_controlnet_composer_sdxl.py @@ -0,0 +1,973 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from .multicontrolnet import MultiControlNetModel +from diffusers.models.controlnet_composer import ControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # To be updated when there's a useful ControlNet checkpoint + >>> # compatible with SDXL. + ``` +""" + + +class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + raise ValueError("MultiControlNet is not yet supported.") + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.watermark = StableDiffusionXLWatermarker() + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image_list, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + output = [] + for image in image_list: + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + output.append(image) + + return output + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + List[List[torch.FloatTensor]], + List[List[PIL.Image.Image]], + List[List[np.ndarray]], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + denoising_end: Optional[float] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = (1024, 1024), + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = (1024, 1024), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` + containing the output images. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + # self.check_inputs( + # prompt, + # image, + # callback_steps, + # negative_prompt, + # prompt_embeds, + # negative_prompt_embeds, + # controlnet_conditioning_scale, + # control_guidance_start, + # control_guidance_end, + # ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image_list=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 7.2 Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None: + num_inference_steps = int(round(denoising_end * num_inference_steps)) + timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps] + + # # 8. Denoising loop + # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/pipelines/pipeline_stable_diffusion_mb_downup.py b/pipelines/pipeline_stable_diffusion_mb_downup.py new file mode 100644 index 0000000000000000000000000000000000000000..ea04e1abf124c925189b7e43e0d522b1d2c76bc6 --- /dev/null +++ b/pipelines/pipeline_stable_diffusion_mb_downup.py @@ -0,0 +1,888 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, aes_watermark, dtype): + if self.args.off_wa: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size + aes_watermark) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + # h = 512, + # w = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + args=None, + batch=None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + aes_watermark: Tuple[float, float] = (600., 10.), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + self.args = args + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # H = torch.tensor([h * batch_size]).cuda() + # W = torch.tensor([w * batch_size]).cuda() + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet.config.in_channels + num_channels_latents = 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + _, c, h, w = latents.shape + + shape = (batch_size * num_images_per_prompt, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + + # midas_depth_latents = latents.clone() + # normal_latents = latents.clone() + midas_depth_latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + normal_latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + if batch is None: + whole_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + whole_latents = self.vae.encode(batch["whole"].to(latents.dtype)).latent_dist.sample() + whole_latents = whole_latents * self.vae.config.scaling_factor + if args.normalize_dist: + whole_latents = (whole_latents - args.whole_mean) / args.whole_std * args.rgb_std + args.rgb_mean + whole_latents_input = torch.cat([whole_latents] * 2) if do_classifier_free_guidance else whole_latents + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, aes_watermark, dtype=prompt_embeds.dtype + ) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + midas_depth_latents_input = torch.cat([midas_depth_latents] * 2) if do_classifier_free_guidance else midas_depth_latents + midas_depth_latents_input = self.scheduler.scale_model_input(midas_depth_latents_input, t) + + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + normal_latents_input = self.scheduler.scale_model_input(normal_latents_input, t) + + _, c, h, w = latent_model_input.shape + + noisy_latents_with_cond = torch.cat([latent_model_input, whole_latents_input], dim=1) + + noisy_latents_list = [torch.cat([midas_depth_latents_input, whole_latents_input], dim=1), torch.cat([normal_latents_input, whole_latents_input], dim=1)] + + added_cond_kwargs = {"time_ids": add_time_ids} + + # predict the noise residual + + noise_pred = self.unet( + noisy_latents_with_cond, + noisy_latents_list, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + rgb_pred = noise_pred[:, :4] + midas_depth_pred = noise_pred[:, 4:8] + normal_pred = noise_pred[:, 8:12] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(rgb_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + midas_depth_latents = self.scheduler.step(midas_depth_pred, t, midas_depth_latents, **extra_step_kwargs, return_dict=False)[0] + normal_latents = self.scheduler.step(normal_pred, t, normal_latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if args.normalize_dist: + midas_depth_latents = (midas_depth_latents - args.rgb_mean) / args.rgb_std * args.depth_std + args.depth_mean + midas_depth_image = self.vae.decode(midas_depth_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if args.normalize_dist: + normal_latents = (normal_latents - args.rgb_mean) / args.rgb_std * args.normal_std + args.normal_mean + normal_image = self.vae.decode(normal_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + if args.normalize_dist: + midas_depth_latents = (midas_depth_latents - args.rgb_mean) / args.rgb_std * args.depth_std + args.depth_mean + midas_depth_image = midas_depth_latents + if args.normalize_dist: + normal_latents = (normal_latents - args.rgb_mean) / args.rgb_std * args.normal_std + args.normal_mean + normal_image = normal_latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + midas_depth_image = self.image_processor.postprocess(midas_depth_image, output_type=output_type, do_denormalize=do_denormalize) + normal_image = self.image_processor.postprocess(normal_image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + output_tuple = (image) + output_tuple = output_tuple + (midas_depth_image) + output_tuple = output_tuple + (normal_image) + return output_tuple + (has_nsfw_concept) + + output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + output["images"] = image + output["midas_depth_image"] = midas_depth_image + output["normal_image"] = normal_image + + return output diff --git a/pipelines/pipeline_stable_diffusion_spade.py b/pipelines/pipeline_stable_diffusion_spade.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea03a9684136421a442c4f58904d7f94a795b72 --- /dev/null +++ b/pipelines/pipeline_stable_diffusion_spade.py @@ -0,0 +1,1371 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + # h = 512, + # w = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + args=None, + batch=None, + depth_embedder=None, + midas_depth_embedder=None, + normal_embedder=None, + canny_embedder=None, + body_embedder=None, + face_embedder=None, + hand_embedder=None, + ldmk_embedder=None, + whole_embedder=None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # H = torch.tensor([h * batch_size]).cuda() + # W = torch.tensor([w * batch_size]).cuda() + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet.config.in_channels + num_channels_latents = 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + _, c, h, w = latents.shape + + shape = (batch_size * num_images_per_prompt, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if "depth" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + depth_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + depth_latents = self.vae.encode(batch["depth"].to(latents.dtype)).latent_dist.sample() + depth_latents = depth_latents * self.vae.config.scaling_factor + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "resize": + if batch is None: + depth_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + depth_latents = F.interpolate(batch['depth'], (h,w)) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + depth_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + depth_latents = depth_embedder(batch['depth']) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + else: + assert False, "unknown condition reshape type" + + if "depth" in args.noisy_cond: + # depth_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # depth_latents = depth_latents * self.scheduler.init_noise_sigma + depth_latents = latents.clone() + + if "midas_depth" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + midas_depth_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + midas_depth_latents = self.vae.encode(batch["midas_depth"].to(latents.dtype)).latent_dist.sample() + midas_depth_latents = midas_depth_latents * self.vae.config.scaling_factor + midas_depth_latents_input = torch.cat([midas_depth_latents] * 2) if do_classifier_free_guidance else midas_depth_latents + elif args.cond_reshape == "resize": + if batch is None: + midas_depth_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + midas_depth_latents = F.interpolate(batch['midas_depth'], (h,w)) + midas_depth_latents_input = torch.cat([midas_depth_latents] * 2) if do_classifier_free_guidance else midas_depth_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + midas_depth_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + midas_depth_latents = midas_depth_embedder(batch['midas_depth']) + midas_depth_latents_input = torch.cat([midas_depth_latents] * 2) if do_classifier_free_guidance else midas_depth_latents + else: + assert False, "unknown condition reshape type" + + if "midas_depth" in args.noisy_cond: + # midas_depth_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # midas_depth_latents = midas_depth_latents * self.scheduler.init_noise_sigma + midas_depth_latents = latents.clone() + + if "normal" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + normal_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + normal_latents = self.vae.encode(batch["normal"].to(latents.dtype)).latent_dist.sample() + normal_latents = normal_latents * self.vae.config.scaling_factor + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "resize": + if batch is None: + normal_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + normal_latents = F.interpolate(batch['normal'], (h,w)) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + normal_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + normal_latents = normal_embedder(batch['normal']) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + else: + assert False, "unknown condition reshape type" + + if "normal" in args.noisy_cond: + # normal_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # normal_latents = normal_latents * self.scheduler.init_noise_sigma + normal_latents = latents.clone() + + if "canny" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + canny_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + canny_latents = self.vae.encode(batch["canny"].to(latents.dtype)).latent_dist.sample() + canny_latents = canny_latents * self.vae.config.scaling_factor + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "resize": + if batch is None: + canny_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + canny_latents = F.interpolate(batch['canny'], (h,w)) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + canny_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + canny_latents = canny_embedder(batch['canny']) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + else: + assert False, "unknown condition reshape type" + + if "canny" in args.noisy_cond: + # canny_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # canny_latents = canny_latents * self.scheduler.init_noise_sigma + canny_latents = latents.clone() + + if "body" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + body_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + body_latents = self.vae.encode(batch["body"].to(latents.dtype)).latent_dist.sample() + body_latents = body_latents * self.vae.config.scaling_factor + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "resize": + if batch is None: + body_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + body_latents = F.interpolate(batch['body'], (h,w)) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + body_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + body_latents = body_embedder(batch['body']) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + else: + assert False, "unknown condition reshape type" + + if "body" in args.noisy_cond: + # body_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # body_latents = body_latents * self.scheduler.init_noise_sigma + body_latents = latents.clone() + + if "face" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + face_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + face_latents = self.vae.encode(batch["face"].to(latents.dtype)).latent_dist.sample() + face_latents = face_latents * self.vae.config.scaling_factor + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "resize": + if batch is None: + face_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + face_latents = F.interpolate(batch['face'], (h,w)) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + face_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + face_latents = face_embedder(batch['face']) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + else: + assert False, "unknown condition reshape type" + + if "face" in args.noisy_cond: + # face_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # face_latents = face_latents * self.scheduler.init_noise_sigma + face_latents = latents.clone() + + if "hand" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + hand_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + hand_latents = self.vae.encode(batch["hand"].to(latents.dtype)).latent_dist.sample() + hand_latents = hand_latents * self.vae.config.scaling_factor + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "resize": + if batch is None: + hand_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + hand_latents = F.interpolate(batch['hand'], (h,w)) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + hand_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + hand_latents = hand_embedder(batch['hand']) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + else: + assert False, "unknown condition reshape type" + + if "hand" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + hand_latents = latents.clone() + + if "whole" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + whole_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + whole_latents = self.vae.encode(batch["whole"].to(latents.dtype)).latent_dist.sample() + whole_latents = whole_latents * self.vae.config.scaling_factor + whole_latents_input = torch.cat([whole_latents] * 2) if do_classifier_free_guidance else whole_latents + elif args.cond_reshape == "resize": + if batch is None: + whole_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + whole_latents = F.interpolate(batch['whole'], (h,w)) + whole_latents_input = torch.cat([whole_latents] * 2) if do_classifier_free_guidance else whole_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + whole_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + whole_latents = whole_embedder(batch['whole']) + whole_latents_input = torch.cat([whole_latents] * 2) if do_classifier_free_guidance else whole_latents + else: + assert False, "unknown condition reshape type" + + if "whole" in args.noisy_cond: + # whole_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # whole_latents = whole_latents * self.scheduler.init_noise_sigma + whole_latents = latents.clone() + + if "ldmk" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + ldmk_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + ldmk_latents = self.vae.encode(batch["ldmk"].to(latents.dtype)).latent_dist.sample() + ldmk_latents = ldmk_latents * self.vae.config.scaling_factor + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "resize": + if batch is None: + ldmk_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + ldmk_latents = F.interpolate(batch['ldmk'], (h,w)) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + ldmk_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + ldmk_latents = ldmk_embedder(batch['ldmk']) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + else: + assert False, "unknown condition reshape type" + + if "ldmk" in args.noisy_cond: + # ldmk_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # ldmk_latents = ldmk_latents * self.scheduler.init_noise_sigma + ldmk_latents = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if "depth" in args.noisy_cond: + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + depth_latents_input = self.scheduler.scale_model_input(depth_latents_input, t) + + if "midas_depth" in args.noisy_cond: + midas_depth_latents_input = torch.cat([midas_depth_latents] * 2) if do_classifier_free_guidance else midas_depth_latents + midas_depth_latents_input = self.scheduler.scale_model_input(midas_depth_latents_input, t) + + if "normal" in args.noisy_cond: + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + normal_latents_input = self.scheduler.scale_model_input(normal_latents_input, t) + + if "canny" in args.noisy_cond: + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + canny_latents_input = self.scheduler.scale_model_input(canny_latents_input, t) + + if "body" in args.noisy_cond: + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + body_latents_input = self.scheduler.scale_model_input(body_latents_input, t) + + if "face" in args.noisy_cond: + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + face_latents_input = self.scheduler.scale_model_input(face_latents_input, t) + + if "hand" in args.noisy_cond: + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + hand_latents_input = self.scheduler.scale_model_input(hand_latents_input, t) + + if "whole" in args.noisy_cond: + whole_latents_input = torch.cat([whole_latents] * 2) if do_classifier_free_guidance else whole_latents + whole_latents_input = self.scheduler.scale_model_input(whole_latents_input, t) + + if "ldmk" in args.noisy_cond: + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + ldmk_latents_input = self.scheduler.scale_model_input(ldmk_latents_input, t) + + _, c, h, w = latent_model_input.shape + + if args.cond_inject == "concat": + latent_model_input = torch.cat([latent_model_input, depth_latents_input], dim=1) if "depth" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, midas_depth_latents_input], dim=1) if "midas_depth" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, normal_latents_input], dim=1) if "normal" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, canny_latents_input], dim=1) if "canny" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, body_latents_input], dim=1) if "body" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, face_latents_input], dim=1) if "face" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, hand_latents_input], dim=1) if "hand" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, whole_latents_input], dim=1) if "whole" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, ldmk_latents_input], dim=1) if "ldmk" in args.cond_type else latent_model_input + elif args.cond_inject == "sum": + if len(args.cond_type) == 0: + pass + else: + if args.cond_reshape == "vae": + channel_dim = 4 + elif args.cond_reshape == "resize": + channel_dim = 3 + elif args.cond_reshape == "learn_conv": + channel_dim = args.embedder_channel + sum_latents = torch.zeros((latent_model_input.shape[0], channel_dim, h, w)).to(self.unet.device) + sum_latents = sum_latents + depth_latents_input if "depth" in args.cond_type else sum_latents + sum_latents = sum_latents + midas_depth_latents_input if "midas_depth" in args.cond_type else sum_latents + sum_latents = sum_latents + normal_latents_input if "normal" in args.cond_type else sum_latents + sum_latents = sum_latents + canny_latents_input if "canny" in args.cond_type else sum_latents + sum_latents = sum_latents + body_latents_input if "body" in args.cond_type else sum_latents + sum_latents = sum_latents + face_latents_input if "face" in args.cond_type else sum_latents + sum_latents = sum_latents + hand_latents_input if "hand" in args.cond_type else sum_latents + sum_latents = sum_latents + whole_latents_input if "whole" in args.cond_type else sum_latents + latent_model_input = torch.cat([latent_model_input, sum_latents], dim=1) + + added_cond_kwargs = {"time_ids": add_time_ids} + + # predict the noise residual + if args.cond_inject == "spade": + if batch is None: + num_cond = 0 + if "depth" in args.cond_type: num_cond += 1 + if "midas_depth" in args.cond_type: num_cond += 1 + if "normal" in args.cond_type: num_cond += 1 + if "canny" in args.cond_type: num_cond += 1 + if "body" in args.cond_type: num_cond += 1 + if "face" in args.cond_type: num_cond += 1 + if "hand" in args.cond_type: num_cond += 1 + if "whole" in args.cond_type: num_cond += 1 + if "ldmk" in args.cond_type: num_cond += 1 + label_channels = num_cond * 3 + structural_cond = torch.zeros((batch_size, label_channels, h, w)).to(self.unet.device) + else: + structural_cond = [] + if "depth" in args.cond_type: + structural_cond.append(batch["depth"]) + if "midas_depth" in args.cond_type: + structural_cond.append(batch["midas_depth"]) + if "normal" in args.cond_type: + structural_cond.append(batch["normal"]) + if "canny" in args.cond_type: + structural_cond.append(batch["canny"]) + if "body" in args.cond_type: + structural_cond.append(batch["body"]) + if "face" in args.cond_type: + structural_cond.append(batch["face"]) + if "hand" in args.cond_type: + structural_cond.append(batch["hand"]) + if "whole" in args.cond_type: + structural_cond.append(batch["whole"]) + if "ldmk" in args.cond_type: + structural_cond.append(batch["ldmk"]) + structural_cond = torch.cat(structural_cond, dim=1) + structural_cond = torch.cat([structural_cond] * 2) if do_classifier_free_guidance else structural_cond + noise_pred = self.unet( + latent_model_input, + structural_cond, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + noise_pred = self.unet( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if noise_pred.shape[1] > 4: + cond_pred = noise_pred[:, 4:] + noise_pred = noise_pred[:, :4] + if "depth" in args.cond_type and "depth" in args.pred_cond: + depth_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "midas_depth" in args.cond_type and "midas_depth" in args.pred_cond: + midas_depth_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "normal" in args.cond_type and "normal" in args.pred_cond: + normal_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "canny" in args.cond_type and "canny" in args.pred_cond: + canny_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "body" in args.cond_type and "body" in args.pred_cond: + body_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "face" in args.cond_type and "face" in args.pred_cond: + face_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "hand" in args.cond_type and "hand" in args.pred_cond: + hand_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "whole" in args.cond_type and "whole" in args.pred_cond: + whole_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if "depth" in args.noisy_cond: + depth_latents = self.scheduler.step(depth_pred, t, depth_latents, **extra_step_kwargs, return_dict=False)[0] + if "midas_depth" in args.noisy_cond: + midas_depth_latents = self.scheduler.step(midas_depth_pred, t, midas_depth_latents, **extra_step_kwargs, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_latents = self.scheduler.step(normal_pred, t, normal_latents, **extra_step_kwargs, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_latents = self.scheduler.step(canny_pred, t, canny_latents, **extra_step_kwargs, return_dict=False)[0] + if "body" in args.noisy_cond: + body_latents = self.scheduler.step(body_pred, t, body_latents, **extra_step_kwargs, return_dict=False)[0] + if "face" in args.noisy_cond: + face_latents = self.scheduler.step(face_pred, t, face_latents, **extra_step_kwargs, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_latents = self.scheduler.step(hand_pred, t, hand_latents, **extra_step_kwargs, return_dict=False)[0] + if "whole" in args.noisy_cond: + whole_latents = self.scheduler.step(whole_pred, t, whole_latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if "depth" in args.pred_cond: + if "depth" in args.noisy_cond: + depth_image = self.vae.decode(depth_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + depth_image = self.vae.decode(depth_pred / self.vae.config.scaling_factor, return_dict=False)[0] + if "midas_depth" in args.pred_cond: + if "midas_depth" in args.noisy_cond: + midas_depth_image = self.vae.decode(midas_depth_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + midas_depth_image = self.vae.decode(midas_depth_pred / self.vae.config.scaling_factor, return_dict=False)[0] + if "normal" in args.pred_cond: + if "normal" in args.noisy_cond: + normal_image = self.vae.decode(normal_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + normal_image = self.vae.decode(normal_pred / self.vae.config.scaling_factor, return_dict=False)[0] + if "canny" in args.pred_cond: + if "canny" in args.noisy_cond: + canny_image = self.vae.decode(canny_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + canny_image = self.vae.decode(canny_pred / self.vae.config.scaling_factor, return_dict=False)[0] + if "body" in args.pred_cond: + if "body" in args.noisy_cond: + body_image = self.vae.decode(body_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + body_image = self.vae.decode(body_pred / self.vae.config.scaling_factor, return_dict=False)[0] + if "face" in args.pred_cond: + if "face" in args.noisy_cond: + face_image = self.vae.decode(face_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + face_image = self.vae.decode(face_pred / self.vae.config.scaling_factor, return_dict=False)[0] + if "hand" in args.pred_cond: + if "hand" in args.noisy_cond: + hand_image = self.vae.decode(hand_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + hand_image = self.vae.decode(hand_pred / self.vae.config.scaling_factor, return_dict=False)[0] + if "whole" in args.pred_cond: + if "whole" in args.noisy_cond: + whole_image = self.vae.decode(whole_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + whole_image = self.vae.decode(whole_pred / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + if "depth" in args.pred_cond: + if "depth" in args.noisy_cond: + depth_image = depth_latents + else: + depth_image = depth_pred + if "midas_depth" in args.pred_cond: + if "midas_depth" in args.noisy_cond: + midas_depth_image = midas_depth_latents + else: + midas_depth_image = midas_depth_pred + if "normal" in args.pred_cond: + if "normal" in args.noisy_cond: + normal_image = normal_latents + else: + normal_image = normal_pred + if "canny" in args.pred_cond: + if "canny" in args.noisy_cond: + canny_image = canny_latents + else: + canny_image = canny_pred + if "body" in args.pred_cond: + if "body" in args.noisy_cond: + body_image = body_latents + else: + body_image = body_pred + if "face" in args.pred_cond: + if "face" in args.noisy_cond: + face_image = face_latents + else: + face_image = face_pred + if "hand" in args.pred_cond: + if "hand" in args.noisy_cond: + hand_image = hand_latents + else: + hand_image = hand_pred + if "whole" in args.pred_cond: + if "whole" in args.noisy_cond: + whole_image = whole_latents + else: + whole_image = whole_pred + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + if "depth" in args.pred_cond: + depth_image = self.image_processor.postprocess(depth_image, output_type=output_type, do_denormalize=do_denormalize) + if "midas_depth" in args.pred_cond: + midas_depth_image = self.image_processor.postprocess(midas_depth_image, output_type=output_type, do_denormalize=do_denormalize) + if "normal" in args.pred_cond: + normal_image = self.image_processor.postprocess(normal_image, output_type=output_type, do_denormalize=do_denormalize) + if "canny" in args.pred_cond: + canny_image = self.image_processor.postprocess(canny_image, output_type=output_type, do_denormalize=do_denormalize) + if "body" in args.pred_cond: + body_image = self.image_processor.postprocess(body_image, output_type=output_type, do_denormalize=do_denormalize) + if "face" in args.pred_cond: + face_image = self.image_processor.postprocess(face_image, output_type=output_type, do_denormalize=do_denormalize) + if "hand" in args.pred_cond: + hand_image = self.image_processor.postprocess(hand_image, output_type=output_type, do_denormalize=do_denormalize) + if "whole" in args.pred_cond: + whole_image = self.image_processor.postprocess(whole_image, output_type=output_type, do_denormalize=do_denormalize) + + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + output_tuple = (image) + if "depth" in args.pred_cond: + output_tuple = output_tuple + (depth_image) + if "midas_depth" in args.pred_cond: + output_tuple = output_tuple + (midas_depth_image) + if "normal" in args.pred_cond: + output_tuple = output_tuple + (normal_image) + if "canny" in args.pred_cond: + output_tuple = output_tuple + (canny_image) + if "body" in args.pred_cond: + output_tuple = output_tuple + (body_image) + if "face" in args.pred_cond: + output_tuple = output_tuple + (face_image) + if "hand" in args.pred_cond: + output_tuple = output_tuple + (hand_image) + if "whole" in args.pred_cond: + output_tuple = output_tuple + (whole_image) + return output_tuple + (has_nsfw_concept) + + output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + if "depth" in args.pred_cond: + output["depth_image"] = depth_image + if "midas_depth" in args.pred_cond: + output["midas_depth_image"] = midas_depth_image + if "normal" in args.pred_cond: + output["normal_image"] = normal_image + if "canny" in args.pred_cond: + output["canny_image"] = canny_image + if "body" in args.pred_cond: + output["body_image"] = body_image + if "face" in args.pred_cond: + output["face_image"] = face_image + if "hand" in args.pred_cond: + output["hand_image"] = hand_image + if "whole" in args.pred_cond: + output["whole_image"] = whole_image + + return output diff --git a/pipelines/pipeline_stable_diffusion_spade_timemoe.py b/pipelines/pipeline_stable_diffusion_spade_timemoe.py new file mode 100644 index 0000000000000000000000000000000000000000..0f800f5b5a04a1aef1ebdd6829179ad8ef3385d6 --- /dev/null +++ b/pipelines/pipeline_stable_diffusion_spade_timemoe.py @@ -0,0 +1,1228 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + unet2: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + new_config = dict(unet2.config) + new_config["sample_size"] = 64 + unet2._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + unet2=unet2, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.unet2, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.unet2, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + # h = 512, + # w = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + args=None, + batch=None, + depth_embedder=None, + normal_embedder=None, + canny_embedder=None, + body_embedder=None, + face_embedder=None, + hand_embedder=None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # H = torch.tensor([h * batch_size]).cuda() + # W = torch.tensor([w * batch_size]).cuda() + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet.config.in_channels + num_channels_latents = 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + _, c, h, w = latents.shape + + shape = (batch_size * num_images_per_prompt, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if "depth" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + depth_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + depth_latents = self.vae.encode(batch["depth"].to(latents.dtype)).latent_dist.sample() + depth_latents = depth_latents * self.vae.config.scaling_factor + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "resize": + if batch is None: + depth_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + depth_latents = F.interpolate(batch['depth'], (h,w)) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + depth_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + depth_latents = depth_embedder(batch['depth']) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + else: + assert False, "unknown condition reshape type" + + if "depth" in args.noisy_cond: + # depth_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # depth_latents = depth_latents * self.scheduler.init_noise_sigma + depth_latents = latents.clone() + + if "normal" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + normal_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + normal_latents = self.vae.encode(batch["normal"].to(latents.dtype)).latent_dist.sample() + normal_latents = normal_latents * self.vae.config.scaling_factor + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "resize": + if batch is None: + normal_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + normal_latents = F.interpolate(batch['normal'], (h,w)) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + normal_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + normal_latents = normal_embedder(batch['normal']) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + else: + assert False, "unknown condition reshape type" + + if "normal" in args.noisy_cond: + # normal_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # normal_latents = normal_latents * self.scheduler.init_noise_sigma + normal_latents = latents.clone() + + if "canny" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + canny_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + canny_latents = self.vae.encode(batch["canny"].to(latents.dtype)).latent_dist.sample() + canny_latents = canny_latents * self.vae.config.scaling_factor + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "resize": + if batch is None: + canny_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + canny_latents = F.interpolate(batch['canny'], (h,w)) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + canny_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + canny_latents = canny_embedder(batch['canny']) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + else: + assert False, "unknown condition reshape type" + + if "canny" in args.noisy_cond: + # canny_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # canny_latents = canny_latents * self.scheduler.init_noise_sigma + canny_latents = latents.clone() + + if "body" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + body_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + body_latents = self.vae.encode(batch["body"].to(latents.dtype)).latent_dist.sample() + body_latents = body_latents * self.vae.config.scaling_factor + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "resize": + if batch is None: + body_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + body_latents = F.interpolate(batch['body'], (h,w)) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + body_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + body_latents = body_embedder(batch['body']) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + else: + assert False, "unknown condition reshape type" + + if "body" in args.noisy_cond: + # body_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # body_latents = body_latents * self.scheduler.init_noise_sigma + body_latents = latents.clone() + + if "face" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + face_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + face_latents = self.vae.encode(batch["face"].to(latents.dtype)).latent_dist.sample() + face_latents = face_latents * self.vae.config.scaling_factor + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "resize": + if batch is None: + face_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + face_latents = F.interpolate(batch['face'], (h,w)) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + face_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + face_latents = face_embedder(batch['face']) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + else: + assert False, "unknown condition reshape type" + + if "face" in args.noisy_cond: + # face_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # face_latents = face_latents * self.scheduler.init_noise_sigma + face_latents = latents.clone() + + if "hand" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + hand_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + hand_latents = self.vae.encode(batch["hand"].to(latents.dtype)).latent_dist.sample() + hand_latents = hand_latents * self.vae.config.scaling_factor + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "resize": + if batch is None: + hand_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + hand_latents = F.interpolate(batch['hand'], (h,w)) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + hand_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + hand_latents = hand_embedder(batch['hand']) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + else: + assert False, "unknown condition reshape type" + + if "hand" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + hand_latents = latents.clone() + + if "ldmk" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + ldmk_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + ldmk_latents = self.vae.encode(batch["ldmk"].to(latents.dtype)).latent_dist.sample() + ldmk_latents = ldmk_latents * self.vae.config.scaling_factor + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "resize": + if batch is None: + ldmk_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + ldmk_latents = F.interpolate(batch['ldmk'], (h,w)) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + ldmk_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + ldmk_latents = hand_embedder(batch['ldmk']) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + else: + assert False, "unknown condition reshape type" + + if "ldmk" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + ldmk_latents = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if "depth" in args.noisy_cond: + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + depth_latents_input = self.scheduler.scale_model_input(depth_latents_input, t) + + if "normal" in args.noisy_cond: + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + normal_latents_input = self.scheduler.scale_model_input(normal_latents_input, t) + + if "canny" in args.noisy_cond: + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + canny_latents_input = self.scheduler.scale_model_input(canny_latents_input, t) + + if "body" in args.noisy_cond: + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + body_latents_input = self.scheduler.scale_model_input(body_latents_input, t) + + if "face" in args.noisy_cond: + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + face_latents_input = self.scheduler.scale_model_input(face_latents_input, t) + + if "hand" in args.noisy_cond: + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + hand_latents_input = self.scheduler.scale_model_input(hand_latents_input, t) + + if "ldmk" in args.noisy_cond: + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + ldmk_latents_input = self.scheduler.scale_model_input(ldmk_latents_input, t) + + _, c, h, w = latent_model_input.shape + + if args.cond_inject == "concat": + latent_model_input = torch.cat([latent_model_input, depth_latents_input], dim=1) if "depth" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, normal_latents_input], dim=1) if "normal" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, canny_latents_input], dim=1) if "canny" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, body_latents_input], dim=1) if "body" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, face_latents_input], dim=1) if "face" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, hand_latents_input], dim=1) if "hand" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, ldmk_latents_input], dim=1) if "ldmk" in args.cond_type else latent_model_input + elif args.cond_inject == "sum": + if len(args.cond_type) == 0: + pass + else: + if args.cond_reshape == "vae": + channel_dim = 4 + elif args.cond_reshape == "resize": + channel_dim = 3 + elif args.cond_reshape == "learn_conv": + channel_dim = args.embedder_channel + sum_latents = torch.zeros((latent_model_input.shape[0], channel_dim, h, w)).to(self.unet.device) + sum_latents = sum_latents + depth_latents_input if "depth" in args.cond_type else sum_latents + sum_latents = sum_latents + normal_latents_input if "normal" in args.cond_type else sum_latents + sum_latents = sum_latents + canny_latents_input if "canny" in args.cond_type else sum_latents + sum_latents = sum_latents + body_latents_input if "body" in args.cond_type else sum_latents + sum_latents = sum_latents + face_latents_input if "face" in args.cond_type else sum_latents + sum_latents = sum_latents + hand_latents_input if "hand" in args.cond_type else sum_latents + latent_model_input = torch.cat([latent_model_input, sum_latents], dim=1) + + added_cond_kwargs = {"time_ids": add_time_ids} + + # predict the noise residual + if args.cond_inject == "spade": + if batch is None: + num_cond = 0 + if "depth" in args.cond_type: num_cond += 1 + if "normal" in args.cond_type: num_cond += 1 + if "canny" in args.cond_type: num_cond += 1 + if "body" in args.cond_type: num_cond += 1 + if "face" in args.cond_type: num_cond += 1 + if "hand" in args.cond_type: num_cond += 1 + if "ldmk" in args.cond_type: num_cond += 1 + label_channels = num_cond * 3 + structural_cond = torch.zeros((batch_size, label_channels, h, w)).to(self.unet.device) + else: + structural_cond = [] + if "depth" in args.cond_type: + structural_cond.append(batch["depth"]) + if "normal" in args.cond_type: + structural_cond.append(batch["normal"]) + if "canny" in args.cond_type: + structural_cond.append(batch["canny"]) + if "body" in args.cond_type: + structural_cond.append(batch["body"]) + if "face" in args.cond_type: + structural_cond.append(batch["face"]) + if "hand" in args.cond_type: + structural_cond.append(batch["hand"]) + if "ldmk" in args.cond_type: + structural_cond.append(batch["ldmk"]) + structural_cond = torch.cat(structural_cond, dim=1) + structural_cond = torch.cat([structural_cond] * 2) if do_classifier_free_guidance else structural_cond + noise_pred = self.unet( + latent_model_input, + structural_cond, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + if t >= self.scheduler.config.num_train_timesteps // 2: + noise_pred = self.unet2( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + noise_pred = self.unet( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if noise_pred.shape[1] > 4: + cond_pred = noise_pred[:, 4:] + noise_pred = noise_pred[:, :4] + if "depth" in args.cond_type: + depth_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "normal" in args.cond_type: + normal_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "canny" in args.cond_type: + canny_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "body" in args.cond_type: + body_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "face" in args.cond_type: + face_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "hand" in args.cond_type: + hand_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if "depth" in args.noisy_cond: + depth_latents = self.scheduler.step(depth_pred, t, depth_latents, **extra_step_kwargs, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_latents = self.scheduler.step(normal_pred, t, normal_latents, **extra_step_kwargs, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_latents = self.scheduler.step(canny_pred, t, canny_latents, **extra_step_kwargs, return_dict=False)[0] + if "body" in args.noisy_cond: + body_latents = self.scheduler.step(body_pred, t, body_latents, **extra_step_kwargs, return_dict=False)[0] + if "face" in args.noisy_cond: + face_latents = self.scheduler.step(face_pred, t, face_latents, **extra_step_kwargs, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_latents = self.scheduler.step(hand_pred, t, hand_latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if "depth" in args.noisy_cond: + depth_image = self.vae.decode(depth_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_image = self.vae.decode(normal_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_image = self.vae.decode(canny_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "body" in args.noisy_cond: + body_image = self.vae.decode(body_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "face" in args.noisy_cond: + face_image = self.vae.decode(face_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_image = self.vae.decode(hand_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + if "depth" in args.noisy_cond: + depth_image = depth_latents + if "normal" in args.noisy_cond: + normal_image = normal_latents + if "canny" in args.noisy_cond: + canny_image = canny_latents + if "body" in args.noisy_cond: + body_image = body_latents + if "face" in args.noisy_cond: + face_image = face_latents + if "hand" in args.noisy_cond: + hand_image = hand_latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + if "depth" in args.noisy_cond: + depth_image = self.image_processor.postprocess(depth_image, output_type=output_type, do_denormalize=do_denormalize) + if "normal" in args.noisy_cond: + normal_image = self.image_processor.postprocess(normal_image, output_type=output_type, do_denormalize=do_denormalize) + if "canny" in args.noisy_cond: + canny_image = self.image_processor.postprocess(canny_image, output_type=output_type, do_denormalize=do_denormalize) + if "body" in args.noisy_cond: + body_image = self.image_processor.postprocess(body_image, output_type=output_type, do_denormalize=do_denormalize) + if "face" in args.noisy_cond: + face_image = self.image_processor.postprocess(face_image, output_type=output_type, do_denormalize=do_denormalize) + if "hand" in args.noisy_cond: + hand_image = self.image_processor.postprocess(hand_image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + output_tuple = (image) + if "depth" in args.noisy_cond: + output_tuple = output_tuple + (depth_image) + if "normal" in args.noisy_cond: + output_tuple = output_tuple + (normal_image) + if "canny" in args.noisy_cond: + output_tuple = output_tuple + (canny_image) + if "body" in args.noisy_cond: + output_tuple = output_tuple + (body_image) + if "face" in args.noisy_cond: + output_tuple = output_tuple + (face_image) + if "hand" in args.noisy_cond: + output_tuple = output_tuple + (hand_image) + return output_tuple + (has_nsfw_concept) + + output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + if "depth" in args.noisy_cond: + output["depth_image"] = depth_image + if "normal" in args.noisy_cond: + output["normal_image"] = normal_image + if "canny" in args.noisy_cond: + output["canny_image"] = canny_image + if "body" in args.noisy_cond: + output["body_image"] = body_image + if "face" in args.noisy_cond: + output["face_image"] = face_image + if "hand" in args.noisy_cond: + output["hand_image"] = hand_image + + return output diff --git a/pipelines/pipeline_stable_diffusion_spade_timemoe3.py b/pipelines/pipeline_stable_diffusion_spade_timemoe3.py new file mode 100644 index 0000000000000000000000000000000000000000..873843419bec1f057c6fc1f900e07fa4ac27774f --- /dev/null +++ b/pipelines/pipeline_stable_diffusion_spade_timemoe3.py @@ -0,0 +1,1245 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + unet2: UNet2DConditionModel, + unet3: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + new_config = dict(unet2.config) + new_config["sample_size"] = 64 + unet2._internal_dict = FrozenDict(new_config) + new_config = dict(unet3.config) + new_config["sample_size"] = 64 + unet3._internal_dict = FrozenDict(new_config) + # new_config = dict(unet4.config) + # new_config["sample_size"] = 64 + # unet4._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + unet2=unet2, + unet3=unet3, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.unet2, self.unet3, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.unet2, self.unet3, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + # h = 512, + # w = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + args=None, + batch=None, + depth_embedder=None, + normal_embedder=None, + canny_embedder=None, + body_embedder=None, + face_embedder=None, + hand_embedder=None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # H = torch.tensor([h * batch_size]).cuda() + # W = torch.tensor([w * batch_size]).cuda() + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet.config.in_channels + num_channels_latents = 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + _, c, h, w = latents.shape + + shape = (batch_size * num_images_per_prompt, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if "depth" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + depth_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + depth_latents = self.vae.encode(batch["depth"].to(latents.dtype)).latent_dist.sample() + depth_latents = depth_latents * self.vae.config.scaling_factor + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "resize": + if batch is None: + depth_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + depth_latents = F.interpolate(batch['depth'], (h,w)) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + depth_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + depth_latents = depth_embedder(batch['depth']) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + else: + assert False, "unknown condition reshape type" + + if "depth" in args.noisy_cond: + # depth_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # depth_latents = depth_latents * self.scheduler.init_noise_sigma + depth_latents = latents.clone() + + if "normal" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + normal_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + normal_latents = self.vae.encode(batch["normal"].to(latents.dtype)).latent_dist.sample() + normal_latents = normal_latents * self.vae.config.scaling_factor + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "resize": + if batch is None: + normal_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + normal_latents = F.interpolate(batch['normal'], (h,w)) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + normal_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + normal_latents = normal_embedder(batch['normal']) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + else: + assert False, "unknown condition reshape type" + + if "normal" in args.noisy_cond: + # normal_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # normal_latents = normal_latents * self.scheduler.init_noise_sigma + normal_latents = latents.clone() + + if "canny" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + canny_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + canny_latents = self.vae.encode(batch["canny"].to(latents.dtype)).latent_dist.sample() + canny_latents = canny_latents * self.vae.config.scaling_factor + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "resize": + if batch is None: + canny_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + canny_latents = F.interpolate(batch['canny'], (h,w)) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + canny_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + canny_latents = canny_embedder(batch['canny']) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + else: + assert False, "unknown condition reshape type" + + if "canny" in args.noisy_cond: + # canny_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # canny_latents = canny_latents * self.scheduler.init_noise_sigma + canny_latents = latents.clone() + + if "body" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + body_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + body_latents = self.vae.encode(batch["body"].to(latents.dtype)).latent_dist.sample() + body_latents = body_latents * self.vae.config.scaling_factor + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "resize": + if batch is None: + body_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + body_latents = F.interpolate(batch['body'], (h,w)) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + body_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + body_latents = body_embedder(batch['body']) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + else: + assert False, "unknown condition reshape type" + + if "body" in args.noisy_cond: + # body_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # body_latents = body_latents * self.scheduler.init_noise_sigma + body_latents = latents.clone() + + if "face" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + face_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + face_latents = self.vae.encode(batch["face"].to(latents.dtype)).latent_dist.sample() + face_latents = face_latents * self.vae.config.scaling_factor + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "resize": + if batch is None: + face_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + face_latents = F.interpolate(batch['face'], (h,w)) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + face_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + face_latents = face_embedder(batch['face']) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + else: + assert False, "unknown condition reshape type" + + if "face" in args.noisy_cond: + # face_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # face_latents = face_latents * self.scheduler.init_noise_sigma + face_latents = latents.clone() + + if "hand" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + hand_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + hand_latents = self.vae.encode(batch["hand"].to(latents.dtype)).latent_dist.sample() + hand_latents = hand_latents * self.vae.config.scaling_factor + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "resize": + if batch is None: + hand_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + hand_latents = F.interpolate(batch['hand'], (h,w)) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + hand_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + hand_latents = hand_embedder(batch['hand']) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + else: + assert False, "unknown condition reshape type" + + if "hand" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + hand_latents = latents.clone() + + if "ldmk" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + ldmk_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + ldmk_latents = self.vae.encode(batch["ldmk"].to(latents.dtype)).latent_dist.sample() + ldmk_latents = ldmk_latents * self.vae.config.scaling_factor + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "resize": + if batch is None: + ldmk_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + ldmk_latents = F.interpolate(batch['ldmk'], (h,w)) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + ldmk_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + ldmk_latents = hand_embedder(batch['ldmk']) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + else: + assert False, "unknown condition reshape type" + + if "ldmk" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + ldmk_latents = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if "depth" in args.noisy_cond: + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + depth_latents_input = self.scheduler.scale_model_input(depth_latents_input, t) + + if "normal" in args.noisy_cond: + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + normal_latents_input = self.scheduler.scale_model_input(normal_latents_input, t) + + if "canny" in args.noisy_cond: + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + canny_latents_input = self.scheduler.scale_model_input(canny_latents_input, t) + + if "body" in args.noisy_cond: + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + body_latents_input = self.scheduler.scale_model_input(body_latents_input, t) + + if "face" in args.noisy_cond: + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + face_latents_input = self.scheduler.scale_model_input(face_latents_input, t) + + if "hand" in args.noisy_cond: + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + hand_latents_input = self.scheduler.scale_model_input(hand_latents_input, t) + + if "ldmk" in args.noisy_cond: + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + ldmk_latents_input = self.scheduler.scale_model_input(ldmk_latents_input, t) + + _, c, h, w = latent_model_input.shape + + if args.cond_inject == "concat": + latent_model_input = torch.cat([latent_model_input, depth_latents_input], dim=1) if "depth" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, normal_latents_input], dim=1) if "normal" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, canny_latents_input], dim=1) if "canny" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, body_latents_input], dim=1) if "body" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, face_latents_input], dim=1) if "face" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, hand_latents_input], dim=1) if "hand" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, ldmk_latents_input], dim=1) if "ldmk" in args.cond_type else latent_model_input + elif args.cond_inject == "sum": + if len(args.cond_type) == 0: + pass + else: + if args.cond_reshape == "vae": + channel_dim = 4 + elif args.cond_reshape == "resize": + channel_dim = 3 + elif args.cond_reshape == "learn_conv": + channel_dim = args.embedder_channel + sum_latents = torch.zeros((latent_model_input.shape[0], channel_dim, h, w)).to(self.unet.device) + sum_latents = sum_latents + depth_latents_input if "depth" in args.cond_type else sum_latents + sum_latents = sum_latents + normal_latents_input if "normal" in args.cond_type else sum_latents + sum_latents = sum_latents + canny_latents_input if "canny" in args.cond_type else sum_latents + sum_latents = sum_latents + body_latents_input if "body" in args.cond_type else sum_latents + sum_latents = sum_latents + face_latents_input if "face" in args.cond_type else sum_latents + sum_latents = sum_latents + hand_latents_input if "hand" in args.cond_type else sum_latents + latent_model_input = torch.cat([latent_model_input, sum_latents], dim=1) + + added_cond_kwargs = {"time_ids": add_time_ids} + + # predict the noise residual + if args.cond_inject == "spade": + if batch is None: + num_cond = 0 + if "depth" in args.cond_type: num_cond += 1 + if "normal" in args.cond_type: num_cond += 1 + if "canny" in args.cond_type: num_cond += 1 + if "body" in args.cond_type: num_cond += 1 + if "face" in args.cond_type: num_cond += 1 + if "hand" in args.cond_type: num_cond += 1 + if "ldmk" in args.cond_type: num_cond += 1 + label_channels = num_cond * 3 + structural_cond = torch.zeros((batch_size, label_channels, h, w)).to(self.unet.device) + else: + structural_cond = [] + if "depth" in args.cond_type: + structural_cond.append(batch["depth"]) + if "normal" in args.cond_type: + structural_cond.append(batch["normal"]) + if "canny" in args.cond_type: + structural_cond.append(batch["canny"]) + if "body" in args.cond_type: + structural_cond.append(batch["body"]) + if "face" in args.cond_type: + structural_cond.append(batch["face"]) + if "hand" in args.cond_type: + structural_cond.append(batch["hand"]) + if "ldmk" in args.cond_type: + structural_cond.append(batch["ldmk"]) + structural_cond = torch.cat(structural_cond, dim=1) + structural_cond = torch.cat([structural_cond] * 2) if do_classifier_free_guidance else structural_cond + noise_pred = self.unet( + latent_model_input, + structural_cond, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + if t <= self.scheduler.config.num_train_timesteps // 4: + noise_pred = self.unet( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + elif t >= self.scheduler.config.num_train_timesteps // 4 and t <= self.scheduler.config.num_train_timesteps // 4 * 2: + noise_pred = self.unet2( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + noise_pred = self.unet3( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if noise_pred.shape[1] > 4: + cond_pred = noise_pred[:, 4:] + noise_pred = noise_pred[:, :4] + if "depth" in args.cond_type: + depth_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "normal" in args.cond_type: + normal_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "canny" in args.cond_type: + canny_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "body" in args.cond_type: + body_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "face" in args.cond_type: + face_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "hand" in args.cond_type: + hand_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if "depth" in args.noisy_cond: + depth_latents = self.scheduler.step(depth_pred, t, depth_latents, **extra_step_kwargs, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_latents = self.scheduler.step(normal_pred, t, normal_latents, **extra_step_kwargs, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_latents = self.scheduler.step(canny_pred, t, canny_latents, **extra_step_kwargs, return_dict=False)[0] + if "body" in args.noisy_cond: + body_latents = self.scheduler.step(body_pred, t, body_latents, **extra_step_kwargs, return_dict=False)[0] + if "face" in args.noisy_cond: + face_latents = self.scheduler.step(face_pred, t, face_latents, **extra_step_kwargs, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_latents = self.scheduler.step(hand_pred, t, hand_latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if "depth" in args.noisy_cond: + depth_image = self.vae.decode(depth_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_image = self.vae.decode(normal_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_image = self.vae.decode(canny_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "body" in args.noisy_cond: + body_image = self.vae.decode(body_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "face" in args.noisy_cond: + face_image = self.vae.decode(face_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_image = self.vae.decode(hand_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + if "depth" in args.noisy_cond: + depth_image = depth_latents + if "normal" in args.noisy_cond: + normal_image = normal_latents + if "canny" in args.noisy_cond: + canny_image = canny_latents + if "body" in args.noisy_cond: + body_image = body_latents + if "face" in args.noisy_cond: + face_image = face_latents + if "hand" in args.noisy_cond: + hand_image = hand_latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + if "depth" in args.noisy_cond: + depth_image = self.image_processor.postprocess(depth_image, output_type=output_type, do_denormalize=do_denormalize) + if "normal" in args.noisy_cond: + normal_image = self.image_processor.postprocess(normal_image, output_type=output_type, do_denormalize=do_denormalize) + if "canny" in args.noisy_cond: + canny_image = self.image_processor.postprocess(canny_image, output_type=output_type, do_denormalize=do_denormalize) + if "body" in args.noisy_cond: + body_image = self.image_processor.postprocess(body_image, output_type=output_type, do_denormalize=do_denormalize) + if "face" in args.noisy_cond: + face_image = self.image_processor.postprocess(face_image, output_type=output_type, do_denormalize=do_denormalize) + if "hand" in args.noisy_cond: + hand_image = self.image_processor.postprocess(hand_image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + output_tuple = (image) + if "depth" in args.noisy_cond: + output_tuple = output_tuple + (depth_image) + if "normal" in args.noisy_cond: + output_tuple = output_tuple + (normal_image) + if "canny" in args.noisy_cond: + output_tuple = output_tuple + (canny_image) + if "body" in args.noisy_cond: + output_tuple = output_tuple + (body_image) + if "face" in args.noisy_cond: + output_tuple = output_tuple + (face_image) + if "hand" in args.noisy_cond: + output_tuple = output_tuple + (hand_image) + return output_tuple + (has_nsfw_concept) + + output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + if "depth" in args.noisy_cond: + output["depth_image"] = depth_image + if "normal" in args.noisy_cond: + output["normal_image"] = normal_image + if "canny" in args.noisy_cond: + output["canny_image"] = canny_image + if "body" in args.noisy_cond: + output["body_image"] = body_image + if "face" in args.noisy_cond: + output["face_image"] = face_image + if "hand" in args.noisy_cond: + output["hand_image"] = hand_image + + return output diff --git a/pipelines/pipeline_stable_diffusion_spade_timemoe4.py b/pipelines/pipeline_stable_diffusion_spade_timemoe4.py new file mode 100644 index 0000000000000000000000000000000000000000..fdbdb6fad9eb85520bae37509eab6f5bc36ed15d --- /dev/null +++ b/pipelines/pipeline_stable_diffusion_spade_timemoe4.py @@ -0,0 +1,1256 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + unet2: UNet2DConditionModel, + unet3: UNet2DConditionModel, + unet4: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + new_config = dict(unet2.config) + new_config["sample_size"] = 64 + unet2._internal_dict = FrozenDict(new_config) + new_config = dict(unet3.config) + new_config["sample_size"] = 64 + unet3._internal_dict = FrozenDict(new_config) + new_config = dict(unet4.config) + new_config["sample_size"] = 64 + unet4._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + unet2=unet2, + unet3=unet3, + unet4=unet4, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.unet2, self.unet3, self.unet4, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.unet2, self.unet3, self.unet4, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + # h = 512, + # w = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + args=None, + batch=None, + depth_embedder=None, + normal_embedder=None, + canny_embedder=None, + body_embedder=None, + face_embedder=None, + hand_embedder=None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # H = torch.tensor([h * batch_size]).cuda() + # W = torch.tensor([w * batch_size]).cuda() + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet.config.in_channels + num_channels_latents = 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + _, c, h, w = latents.shape + + shape = (batch_size * num_images_per_prompt, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if "depth" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + depth_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + depth_latents = self.vae.encode(batch["depth"].to(latents.dtype)).latent_dist.sample() + depth_latents = depth_latents * self.vae.config.scaling_factor + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "resize": + if batch is None: + depth_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + depth_latents = F.interpolate(batch['depth'], (h,w)) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + depth_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + depth_latents = depth_embedder(batch['depth']) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + else: + assert False, "unknown condition reshape type" + + if "depth" in args.noisy_cond: + # depth_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # depth_latents = depth_latents * self.scheduler.init_noise_sigma + depth_latents = latents.clone() + + if "normal" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + normal_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + normal_latents = self.vae.encode(batch["normal"].to(latents.dtype)).latent_dist.sample() + normal_latents = normal_latents * self.vae.config.scaling_factor + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "resize": + if batch is None: + normal_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + normal_latents = F.interpolate(batch['normal'], (h,w)) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + normal_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + normal_latents = normal_embedder(batch['normal']) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + else: + assert False, "unknown condition reshape type" + + if "normal" in args.noisy_cond: + # normal_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # normal_latents = normal_latents * self.scheduler.init_noise_sigma + normal_latents = latents.clone() + + if "canny" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + canny_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + canny_latents = self.vae.encode(batch["canny"].to(latents.dtype)).latent_dist.sample() + canny_latents = canny_latents * self.vae.config.scaling_factor + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "resize": + if batch is None: + canny_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + canny_latents = F.interpolate(batch['canny'], (h,w)) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + canny_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + canny_latents = canny_embedder(batch['canny']) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + else: + assert False, "unknown condition reshape type" + + if "canny" in args.noisy_cond: + # canny_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # canny_latents = canny_latents * self.scheduler.init_noise_sigma + canny_latents = latents.clone() + + if "body" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + body_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + body_latents = self.vae.encode(batch["body"].to(latents.dtype)).latent_dist.sample() + body_latents = body_latents * self.vae.config.scaling_factor + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "resize": + if batch is None: + body_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + body_latents = F.interpolate(batch['body'], (h,w)) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + body_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + body_latents = body_embedder(batch['body']) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + else: + assert False, "unknown condition reshape type" + + if "body" in args.noisy_cond: + # body_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # body_latents = body_latents * self.scheduler.init_noise_sigma + body_latents = latents.clone() + + if "face" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + face_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + face_latents = self.vae.encode(batch["face"].to(latents.dtype)).latent_dist.sample() + face_latents = face_latents * self.vae.config.scaling_factor + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "resize": + if batch is None: + face_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + face_latents = F.interpolate(batch['face'], (h,w)) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + face_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + face_latents = face_embedder(batch['face']) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + else: + assert False, "unknown condition reshape type" + + if "face" in args.noisy_cond: + # face_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # face_latents = face_latents * self.scheduler.init_noise_sigma + face_latents = latents.clone() + + if "hand" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + hand_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + hand_latents = self.vae.encode(batch["hand"].to(latents.dtype)).latent_dist.sample() + hand_latents = hand_latents * self.vae.config.scaling_factor + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "resize": + if batch is None: + hand_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + hand_latents = F.interpolate(batch['hand'], (h,w)) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + hand_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + hand_latents = hand_embedder(batch['hand']) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + else: + assert False, "unknown condition reshape type" + + if "hand" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + hand_latents = latents.clone() + + if "ldmk" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + ldmk_latents = torch.zeros((batch_size, c, h, w)).to(self.unet.device) + else: + ldmk_latents = self.vae.encode(batch["ldmk"].to(latents.dtype)).latent_dist.sample() + ldmk_latents = ldmk_latents * self.vae.config.scaling_factor + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "resize": + if batch is None: + ldmk_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet.device) + else: + ldmk_latents = F.interpolate(batch['ldmk'], (h,w)) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + ldmk_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet.device) + else: + ldmk_latents = hand_embedder(batch['ldmk']) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + else: + assert False, "unknown condition reshape type" + + if "ldmk" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + ldmk_latents = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if "depth" in args.noisy_cond: + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + depth_latents_input = self.scheduler.scale_model_input(depth_latents_input, t) + + if "normal" in args.noisy_cond: + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + normal_latents_input = self.scheduler.scale_model_input(normal_latents_input, t) + + if "canny" in args.noisy_cond: + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + canny_latents_input = self.scheduler.scale_model_input(canny_latents_input, t) + + if "body" in args.noisy_cond: + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + body_latents_input = self.scheduler.scale_model_input(body_latents_input, t) + + if "face" in args.noisy_cond: + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + face_latents_input = self.scheduler.scale_model_input(face_latents_input, t) + + if "hand" in args.noisy_cond: + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + hand_latents_input = self.scheduler.scale_model_input(hand_latents_input, t) + + if "ldmk" in args.noisy_cond: + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + ldmk_latents_input = self.scheduler.scale_model_input(ldmk_latents_input, t) + + _, c, h, w = latent_model_input.shape + + if args.cond_inject == "concat": + latent_model_input = torch.cat([latent_model_input, depth_latents_input], dim=1) if "depth" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, normal_latents_input], dim=1) if "normal" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, canny_latents_input], dim=1) if "canny" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, body_latents_input], dim=1) if "body" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, face_latents_input], dim=1) if "face" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, hand_latents_input], dim=1) if "hand" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, ldmk_latents_input], dim=1) if "ldmk" in args.cond_type else latent_model_input + elif args.cond_inject == "sum": + if len(args.cond_type) == 0: + pass + else: + if args.cond_reshape == "vae": + channel_dim = 4 + elif args.cond_reshape == "resize": + channel_dim = 3 + elif args.cond_reshape == "learn_conv": + channel_dim = args.embedder_channel + sum_latents = torch.zeros((latent_model_input.shape[0], channel_dim, h, w)).to(self.unet.device) + sum_latents = sum_latents + depth_latents_input if "depth" in args.cond_type else sum_latents + sum_latents = sum_latents + normal_latents_input if "normal" in args.cond_type else sum_latents + sum_latents = sum_latents + canny_latents_input if "canny" in args.cond_type else sum_latents + sum_latents = sum_latents + body_latents_input if "body" in args.cond_type else sum_latents + sum_latents = sum_latents + face_latents_input if "face" in args.cond_type else sum_latents + sum_latents = sum_latents + hand_latents_input if "hand" in args.cond_type else sum_latents + latent_model_input = torch.cat([latent_model_input, sum_latents], dim=1) + + added_cond_kwargs = {"time_ids": add_time_ids} + + # predict the noise residual + if args.cond_inject == "spade": + if batch is None: + num_cond = 0 + if "depth" in args.cond_type: num_cond += 1 + if "normal" in args.cond_type: num_cond += 1 + if "canny" in args.cond_type: num_cond += 1 + if "body" in args.cond_type: num_cond += 1 + if "face" in args.cond_type: num_cond += 1 + if "hand" in args.cond_type: num_cond += 1 + if "ldmk" in args.cond_type: num_cond += 1 + label_channels = num_cond * 3 + structural_cond = torch.zeros((batch_size, label_channels, h, w)).to(self.unet.device) + else: + structural_cond = [] + if "depth" in args.cond_type: + structural_cond.append(batch["depth"]) + if "normal" in args.cond_type: + structural_cond.append(batch["normal"]) + if "canny" in args.cond_type: + structural_cond.append(batch["canny"]) + if "body" in args.cond_type: + structural_cond.append(batch["body"]) + if "face" in args.cond_type: + structural_cond.append(batch["face"]) + if "hand" in args.cond_type: + structural_cond.append(batch["hand"]) + if "ldmk" in args.cond_type: + structural_cond.append(batch["ldmk"]) + structural_cond = torch.cat(structural_cond, dim=1) + structural_cond = torch.cat([structural_cond] * 2) if do_classifier_free_guidance else structural_cond + noise_pred = self.unet( + latent_model_input, + structural_cond, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + if t <= self.scheduler.config.num_train_timesteps // 4: + noise_pred = self.unet( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + elif t >= self.scheduler.config.num_train_timesteps // 4 and t <= self.scheduler.config.num_train_timesteps // 4 * 2: + noise_pred = self.unet2( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + elif t >= self.scheduler.config.num_train_timesteps // 4 * 2 and t <= self.scheduler.config.num_train_timesteps // 4 * 3: + noise_pred = self.unet3( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + noise_pred = self.unet4( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if noise_pred.shape[1] > 4: + cond_pred = noise_pred[:, 4:] + noise_pred = noise_pred[:, :4] + if "depth" in args.cond_type: + depth_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "normal" in args.cond_type: + normal_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "canny" in args.cond_type: + canny_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "body" in args.cond_type: + body_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "face" in args.cond_type: + face_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "hand" in args.cond_type: + hand_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if "depth" in args.noisy_cond: + depth_latents = self.scheduler.step(depth_pred, t, depth_latents, **extra_step_kwargs, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_latents = self.scheduler.step(normal_pred, t, normal_latents, **extra_step_kwargs, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_latents = self.scheduler.step(canny_pred, t, canny_latents, **extra_step_kwargs, return_dict=False)[0] + if "body" in args.noisy_cond: + body_latents = self.scheduler.step(body_pred, t, body_latents, **extra_step_kwargs, return_dict=False)[0] + if "face" in args.noisy_cond: + face_latents = self.scheduler.step(face_pred, t, face_latents, **extra_step_kwargs, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_latents = self.scheduler.step(hand_pred, t, hand_latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if "depth" in args.noisy_cond: + depth_image = self.vae.decode(depth_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_image = self.vae.decode(normal_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_image = self.vae.decode(canny_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "body" in args.noisy_cond: + body_image = self.vae.decode(body_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "face" in args.noisy_cond: + face_image = self.vae.decode(face_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_image = self.vae.decode(hand_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + if "depth" in args.noisy_cond: + depth_image = depth_latents + if "normal" in args.noisy_cond: + normal_image = normal_latents + if "canny" in args.noisy_cond: + canny_image = canny_latents + if "body" in args.noisy_cond: + body_image = body_latents + if "face" in args.noisy_cond: + face_image = face_latents + if "hand" in args.noisy_cond: + hand_image = hand_latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + if "depth" in args.noisy_cond: + depth_image = self.image_processor.postprocess(depth_image, output_type=output_type, do_denormalize=do_denormalize) + if "normal" in args.noisy_cond: + normal_image = self.image_processor.postprocess(normal_image, output_type=output_type, do_denormalize=do_denormalize) + if "canny" in args.noisy_cond: + canny_image = self.image_processor.postprocess(canny_image, output_type=output_type, do_denormalize=do_denormalize) + if "body" in args.noisy_cond: + body_image = self.image_processor.postprocess(body_image, output_type=output_type, do_denormalize=do_denormalize) + if "face" in args.noisy_cond: + face_image = self.image_processor.postprocess(face_image, output_type=output_type, do_denormalize=do_denormalize) + if "hand" in args.noisy_cond: + hand_image = self.image_processor.postprocess(hand_image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + output_tuple = (image) + if "depth" in args.noisy_cond: + output_tuple = output_tuple + (depth_image) + if "normal" in args.noisy_cond: + output_tuple = output_tuple + (normal_image) + if "canny" in args.noisy_cond: + output_tuple = output_tuple + (canny_image) + if "body" in args.noisy_cond: + output_tuple = output_tuple + (body_image) + if "face" in args.noisy_cond: + output_tuple = output_tuple + (face_image) + if "hand" in args.noisy_cond: + output_tuple = output_tuple + (hand_image) + return output_tuple + (has_nsfw_concept) + + output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + if "depth" in args.noisy_cond: + output["depth_image"] = depth_image + if "normal" in args.noisy_cond: + output["normal_image"] = normal_image + if "canny" in args.noisy_cond: + output["canny_image"] = canny_image + if "body" in args.noisy_cond: + output["body_image"] = body_image + if "face" in args.noisy_cond: + output["face_image"] = face_image + if "hand" in args.noisy_cond: + output["hand_image"] = hand_image + + return output diff --git a/pipelines/pipeline_stable_diffusion_spade_timemoe5.py b/pipelines/pipeline_stable_diffusion_spade_timemoe5.py new file mode 100644 index 0000000000000000000000000000000000000000..6e225b790bb243555b36deb32200e6315974ca66 --- /dev/null +++ b/pipelines/pipeline_stable_diffusion_spade_timemoe5.py @@ -0,0 +1,1270 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet_0_200 ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet_0_200` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet_0_200: UNet2DConditionModel, + unet_200_400: UNet2DConditionModel, + unet_400_600: UNet2DConditionModel, + unet_600_800: UNet2DConditionModel, + unet_800_1000: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet_0_200.config, "_diffusers_version") and version.parse( + version.parse(unet_0_200.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet_0_200.config, "sample_size") and unet_0_200.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet_0_200 has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet_0_200/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet_0_200.config) + new_config["sample_size"] = 64 + unet_0_200._internal_dict = FrozenDict(new_config) + new_config = dict(unet_200_400.config) + new_config["sample_size"] = 64 + unet_200_400._internal_dict = FrozenDict(new_config) + new_config = dict(unet_400_600.config) + new_config["sample_size"] = 64 + unet_400_600._internal_dict = FrozenDict(new_config) + new_config = dict(unet_600_800.config) + new_config["sample_size"] = 64 + unet_600_800._internal_dict = FrozenDict(new_config) + new_config = dict(unet_800_1000.config) + new_config["sample_size"] = 64 + unet_800_1000._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet_0_200=unet_0_200, + unet_200_400=unet_200_400, + unet_400_600=unet_400_600, + unet_600_800=unet_600_800, + unet_800_1000=unet_800_1000, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet_0_200, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet_0_200, self.unet_200_400, self.unet_400_600, self.unet_600_800, self.unet_800_1000, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet_0_200`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet_0_200, self.unet_200_400, self.unet_400_600, self.unet_600_800, self.unet_800_1000, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet_0_200, "_hf_hook"): + return self.device + for module in self.unet_0_200.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + # h = 512, + # w = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + args=None, + batch=None, + depth_embedder=None, + normal_embedder=None, + canny_embedder=None, + body_embedder=None, + face_embedder=None, + hand_embedder=None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet_0_200.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet_0_200.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet_0_200 + height = height or self.unet_0_200.config.sample_size * self.vae_scale_factor + width = width or self.unet_0_200.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # H = torch.tensor([h * batch_size]).cuda() + # W = torch.tensor([w * batch_size]).cuda() + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet_0_200.config.in_channels + num_channels_latents = 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + _, c, h, w = latents.shape + + shape = (batch_size * num_images_per_prompt, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if "depth" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + depth_latents = torch.zeros((batch_size, c, h, w)).to(self.unet_0_200.device) + else: + depth_latents = self.vae.encode(batch["depth"].to(latents.dtype)).latent_dist.sample() + depth_latents = depth_latents * self.vae.config.scaling_factor + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "resize": + if batch is None: + depth_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet_0_200.device) + else: + depth_latents = F.interpolate(batch['depth'], (h,w)) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + depth_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet_0_200.device) + else: + depth_latents = depth_embedder(batch['depth']) + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + else: + assert False, "unknown condition reshape type" + + if "depth" in args.noisy_cond: + # depth_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # depth_latents = depth_latents * self.scheduler.init_noise_sigma + depth_latents = latents.clone() + + if "normal" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + normal_latents = torch.zeros((batch_size, c, h, w)).to(self.unet_0_200.device) + else: + normal_latents = self.vae.encode(batch["normal"].to(latents.dtype)).latent_dist.sample() + normal_latents = normal_latents * self.vae.config.scaling_factor + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "resize": + if batch is None: + normal_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet_0_200.device) + else: + normal_latents = F.interpolate(batch['normal'], (h,w)) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + normal_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet_0_200.device) + else: + normal_latents = normal_embedder(batch['normal']) + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + else: + assert False, "unknown condition reshape type" + + if "normal" in args.noisy_cond: + # normal_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # normal_latents = normal_latents * self.scheduler.init_noise_sigma + normal_latents = latents.clone() + + if "canny" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + canny_latents = torch.zeros((batch_size, c, h, w)).to(self.unet_0_200.device) + else: + canny_latents = self.vae.encode(batch["canny"].to(latents.dtype)).latent_dist.sample() + canny_latents = canny_latents * self.vae.config.scaling_factor + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "resize": + if batch is None: + canny_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet_0_200.device) + else: + canny_latents = F.interpolate(batch['canny'], (h,w)) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + canny_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet_0_200.device) + else: + canny_latents = canny_embedder(batch['canny']) + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + else: + assert False, "unknown condition reshape type" + + if "canny" in args.noisy_cond: + # canny_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # canny_latents = canny_latents * self.scheduler.init_noise_sigma + canny_latents = latents.clone() + + if "body" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + body_latents = torch.zeros((batch_size, c, h, w)).to(self.unet_0_200.device) + else: + body_latents = self.vae.encode(batch["body"].to(latents.dtype)).latent_dist.sample() + body_latents = body_latents * self.vae.config.scaling_factor + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "resize": + if batch is None: + body_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet_0_200.device) + else: + body_latents = F.interpolate(batch['body'], (h,w)) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + body_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet_0_200.device) + else: + body_latents = body_embedder(batch['body']) + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + else: + assert False, "unknown condition reshape type" + + if "body" in args.noisy_cond: + # body_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # body_latents = body_latents * self.scheduler.init_noise_sigma + body_latents = latents.clone() + + if "face" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + face_latents = torch.zeros((batch_size, c, h, w)).to(self.unet_0_200.device) + else: + face_latents = self.vae.encode(batch["face"].to(latents.dtype)).latent_dist.sample() + face_latents = face_latents * self.vae.config.scaling_factor + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "resize": + if batch is None: + face_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet_0_200.device) + else: + face_latents = F.interpolate(batch['face'], (h,w)) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + face_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet_0_200.device) + else: + face_latents = face_embedder(batch['face']) + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + else: + assert False, "unknown condition reshape type" + + if "face" in args.noisy_cond: + # face_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # face_latents = face_latents * self.scheduler.init_noise_sigma + face_latents = latents.clone() + + if "hand" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + hand_latents = torch.zeros((batch_size, c, h, w)).to(self.unet_0_200.device) + else: + hand_latents = self.vae.encode(batch["hand"].to(latents.dtype)).latent_dist.sample() + hand_latents = hand_latents * self.vae.config.scaling_factor + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "resize": + if batch is None: + hand_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet_0_200.device) + else: + hand_latents = F.interpolate(batch['hand'], (h,w)) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + hand_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet_0_200.device) + else: + hand_latents = hand_embedder(batch['hand']) + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + else: + assert False, "unknown condition reshape type" + + if "hand" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + hand_latents = latents.clone() + + if "ldmk" in args.cond_type: + if args.cond_reshape == "vae": + if batch is None: + ldmk_latents = torch.zeros((batch_size, c, h, w)).to(self.unet_0_200.device) + else: + ldmk_latents = self.vae.encode(batch["ldmk"].to(latents.dtype)).latent_dist.sample() + ldmk_latents = ldmk_latents * self.vae.config.scaling_factor + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "resize": + if batch is None: + ldmk_latents = torch.zeros((batch_size, 3, h, w)).to(self.unet_0_200.device) + else: + ldmk_latents = F.interpolate(batch['ldmk'], (h,w)) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + elif args.cond_reshape == "learn_conv": + if batch is None: + ldmk_latents = torch.zeros((batch_size, args.embedder_channel, h, w)).to(self.unet_0_200.device) + else: + ldmk_latents = hand_embedder(batch['ldmk']) + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + else: + assert False, "unknown condition reshape type" + + if "ldmk" in args.noisy_cond: + # hand_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + # hand_latents = hand_latents * self.scheduler.init_noise_sigma + ldmk_latents = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if "depth" in args.noisy_cond: + depth_latents_input = torch.cat([depth_latents] * 2) if do_classifier_free_guidance else depth_latents + depth_latents_input = self.scheduler.scale_model_input(depth_latents_input, t) + + if "normal" in args.noisy_cond: + normal_latents_input = torch.cat([normal_latents] * 2) if do_classifier_free_guidance else normal_latents + normal_latents_input = self.scheduler.scale_model_input(normal_latents_input, t) + + if "canny" in args.noisy_cond: + canny_latents_input = torch.cat([canny_latents] * 2) if do_classifier_free_guidance else canny_latents + canny_latents_input = self.scheduler.scale_model_input(canny_latents_input, t) + + if "body" in args.noisy_cond: + body_latents_input = torch.cat([body_latents] * 2) if do_classifier_free_guidance else body_latents + body_latents_input = self.scheduler.scale_model_input(body_latents_input, t) + + if "face" in args.noisy_cond: + face_latents_input = torch.cat([face_latents] * 2) if do_classifier_free_guidance else face_latents + face_latents_input = self.scheduler.scale_model_input(face_latents_input, t) + + if "hand" in args.noisy_cond: + hand_latents_input = torch.cat([hand_latents] * 2) if do_classifier_free_guidance else hand_latents + hand_latents_input = self.scheduler.scale_model_input(hand_latents_input, t) + + if "ldmk" in args.noisy_cond: + ldmk_latents_input = torch.cat([ldmk_latents] * 2) if do_classifier_free_guidance else ldmk_latents + ldmk_latents_input = self.scheduler.scale_model_input(ldmk_latents_input, t) + + _, c, h, w = latent_model_input.shape + + if args.cond_inject == "concat": + latent_model_input = torch.cat([latent_model_input, depth_latents_input], dim=1) if "depth" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, normal_latents_input], dim=1) if "normal" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, canny_latents_input], dim=1) if "canny" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, body_latents_input], dim=1) if "body" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, face_latents_input], dim=1) if "face" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, hand_latents_input], dim=1) if "hand" in args.cond_type else latent_model_input + latent_model_input = torch.cat([latent_model_input, ldmk_latents_input], dim=1) if "ldmk" in args.cond_type else latent_model_input + elif args.cond_inject == "sum": + if len(args.cond_type) == 0: + pass + else: + if args.cond_reshape == "vae": + channel_dim = 4 + elif args.cond_reshape == "resize": + channel_dim = 3 + elif args.cond_reshape == "learn_conv": + channel_dim = args.embedder_channel + sum_latents = torch.zeros((latent_model_input.shape[0], channel_dim, h, w)).to(self.unet_0_200.device) + sum_latents = sum_latents + depth_latents_input if "depth" in args.cond_type else sum_latents + sum_latents = sum_latents + normal_latents_input if "normal" in args.cond_type else sum_latents + sum_latents = sum_latents + canny_latents_input if "canny" in args.cond_type else sum_latents + sum_latents = sum_latents + body_latents_input if "body" in args.cond_type else sum_latents + sum_latents = sum_latents + face_latents_input if "face" in args.cond_type else sum_latents + sum_latents = sum_latents + hand_latents_input if "hand" in args.cond_type else sum_latents + latent_model_input = torch.cat([latent_model_input, sum_latents], dim=1) + + added_cond_kwargs = {"time_ids": add_time_ids} + + # predict the noise residual + if args.cond_inject == "spade": + if batch is None: + num_cond = 0 + if "depth" in args.cond_type: num_cond += 1 + if "normal" in args.cond_type: num_cond += 1 + if "canny" in args.cond_type: num_cond += 1 + if "body" in args.cond_type: num_cond += 1 + if "face" in args.cond_type: num_cond += 1 + if "hand" in args.cond_type: num_cond += 1 + if "ldmk" in args.cond_type: num_cond += 1 + label_channels = num_cond * 3 + structural_cond = torch.zeros((batch_size, label_channels, h, w)).to(self.unet_0_200.device) + else: + structural_cond = [] + if "depth" in args.cond_type: + structural_cond.append(batch["depth"]) + if "normal" in args.cond_type: + structural_cond.append(batch["normal"]) + if "canny" in args.cond_type: + structural_cond.append(batch["canny"]) + if "body" in args.cond_type: + structural_cond.append(batch["body"]) + if "face" in args.cond_type: + structural_cond.append(batch["face"]) + if "hand" in args.cond_type: + structural_cond.append(batch["hand"]) + if "ldmk" in args.cond_type: + structural_cond.append(batch["ldmk"]) + structural_cond = torch.cat(structural_cond, dim=1) + structural_cond = torch.cat([structural_cond] * 2) if do_classifier_free_guidance else structural_cond + noise_pred = self.unet_0_200( + latent_model_input, + structural_cond, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + if t < 200: + noise_pred = self.unet_0_200( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + elif t >= 200 and t < 400: + noise_pred = self.unet_200_400( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + elif t >= 400 and t < 600: + noise_pred = self.unet_400_600( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + elif t >= 600 and t < 800: + noise_pred = self.unet_600_800( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + else: + noise_pred = self.unet_800_1000( + latent_model_input, + t, + added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if noise_pred.shape[1] > 4: + cond_pred = noise_pred[:, 4:] + noise_pred = noise_pred[:, :4] + if "depth" in args.cond_type: + depth_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "normal" in args.cond_type: + normal_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "canny" in args.cond_type: + canny_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "body" in args.cond_type: + body_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "face" in args.cond_type: + face_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + if "hand" in args.cond_type: + hand_pred = cond_pred[:, :4] + if cond_pred.shape[1] > 4: + cond_pred = cond_pred[:, 4:] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if "depth" in args.noisy_cond: + depth_latents = self.scheduler.step(depth_pred, t, depth_latents, **extra_step_kwargs, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_latents = self.scheduler.step(normal_pred, t, normal_latents, **extra_step_kwargs, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_latents = self.scheduler.step(canny_pred, t, canny_latents, **extra_step_kwargs, return_dict=False)[0] + if "body" in args.noisy_cond: + body_latents = self.scheduler.step(body_pred, t, body_latents, **extra_step_kwargs, return_dict=False)[0] + if "face" in args.noisy_cond: + face_latents = self.scheduler.step(face_pred, t, face_latents, **extra_step_kwargs, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_latents = self.scheduler.step(hand_pred, t, hand_latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if "depth" in args.noisy_cond: + depth_image = self.vae.decode(depth_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "normal" in args.noisy_cond: + normal_image = self.vae.decode(normal_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "canny" in args.noisy_cond: + canny_image = self.vae.decode(canny_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "body" in args.noisy_cond: + body_image = self.vae.decode(body_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "face" in args.noisy_cond: + face_image = self.vae.decode(face_latents / self.vae.config.scaling_factor, return_dict=False)[0] + if "hand" in args.noisy_cond: + hand_image = self.vae.decode(hand_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + if "depth" in args.noisy_cond: + depth_image = depth_latents + if "normal" in args.noisy_cond: + normal_image = normal_latents + if "canny" in args.noisy_cond: + canny_image = canny_latents + if "body" in args.noisy_cond: + body_image = body_latents + if "face" in args.noisy_cond: + face_image = face_latents + if "hand" in args.noisy_cond: + hand_image = hand_latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + if "depth" in args.noisy_cond: + depth_image = self.image_processor.postprocess(depth_image, output_type=output_type, do_denormalize=do_denormalize) + if "normal" in args.noisy_cond: + normal_image = self.image_processor.postprocess(normal_image, output_type=output_type, do_denormalize=do_denormalize) + if "canny" in args.noisy_cond: + canny_image = self.image_processor.postprocess(canny_image, output_type=output_type, do_denormalize=do_denormalize) + if "body" in args.noisy_cond: + body_image = self.image_processor.postprocess(body_image, output_type=output_type, do_denormalize=do_denormalize) + if "face" in args.noisy_cond: + face_image = self.image_processor.postprocess(face_image, output_type=output_type, do_denormalize=do_denormalize) + if "hand" in args.noisy_cond: + hand_image = self.image_processor.postprocess(hand_image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + output_tuple = (image) + if "depth" in args.noisy_cond: + output_tuple = output_tuple + (depth_image) + if "normal" in args.noisy_cond: + output_tuple = output_tuple + (normal_image) + if "canny" in args.noisy_cond: + output_tuple = output_tuple + (canny_image) + if "body" in args.noisy_cond: + output_tuple = output_tuple + (body_image) + if "face" in args.noisy_cond: + output_tuple = output_tuple + (face_image) + if "hand" in args.noisy_cond: + output_tuple = output_tuple + (hand_image) + return output_tuple + (has_nsfw_concept) + + output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + if "depth" in args.noisy_cond: + output["depth_image"] = depth_image + if "normal" in args.noisy_cond: + output["normal_image"] = normal_image + if "canny" in args.noisy_cond: + output["canny_image"] = canny_image + if "body" in args.noisy_cond: + output["body_image"] = body_image + if "face" in args.noisy_cond: + output["face_image"] = face_image + if "hand" in args.noisy_cond: + output["hand_image"] = hand_image + + return output diff --git a/pipelines/pipeline_stable_diffusion_xl.py b/pipelines/pipeline_stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..5e0af0c4c5a68c86758e656a8da29ef174ccd8e5 --- /dev/null +++ b/pipelines/pipeline_stable_diffusion_xl.py @@ -0,0 +1,812 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + + self.watermark = StableDiffusionXLWatermarker() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + model_sequence.extend([self.unet, self.vae]) + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to + 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50) + denoising steps. As a result, the returned sample will still retain a substantial amount of noise. The + denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of + Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None: + num_inference_steps = int(round(denoising_end * num_inference_steps)) + timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..8d21ad680782ef9e1f29897823ac6763ee0c32cb --- /dev/null +++ b/run.sh @@ -0,0 +1 @@ +python gradio_humanpose2image.py diff --git a/share.py b/share.py new file mode 100644 index 0000000000000000000000000000000000000000..463af08fb936d650b5dd2e66183661181c34a3d6 --- /dev/null +++ b/share.py @@ -0,0 +1,8 @@ +import config +from cldm.hack import disable_verbosity, enable_sliced_attention + + +disable_verbosity() + +if config.save_memory: + enable_sliced_attention() diff --git a/test_imgs/bag.png b/test_imgs/bag.png new file mode 100644 index 0000000000000000000000000000000000000000..ae25f755061850f98bae8b336ee834473a923fe9 Binary files /dev/null and b/test_imgs/bag.png differ diff --git a/test_imgs/bag_scribble.png b/test_imgs/bag_scribble.png new file mode 100644 index 0000000000000000000000000000000000000000..9d5991ce12fe04517dc27cfea12fd1bb72ada649 Binary files /dev/null and b/test_imgs/bag_scribble.png differ diff --git a/test_imgs/bird.png b/test_imgs/bird.png new file mode 100644 index 0000000000000000000000000000000000000000..544da68fdbdfda5befa0228d5fbc740d842bf766 --- /dev/null +++ b/test_imgs/bird.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cad49fc7d3071b2bcd078bc8dde365f8fa62eaa6d43705fd50c212794a3aac35 +size 1065314 diff --git a/test_imgs/boy.png b/test_imgs/boy.png new file mode 100644 index 0000000000000000000000000000000000000000..c4a751e31da45af83c8a3d5ec02cf8c22c7bb8e9 Binary files /dev/null and b/test_imgs/boy.png differ diff --git a/test_imgs/building.png b/test_imgs/building.png new file mode 100644 index 0000000000000000000000000000000000000000..7ecf60a6f021e906e1781c9d4a6c77fea0fa2302 --- /dev/null +++ b/test_imgs/building.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab98191cab43d1cf19c408cb900250fb8a538d114979b556990a053b0ecb788c +size 1135527 diff --git a/test_imgs/building2.png b/test_imgs/building2.png new file mode 100644 index 0000000000000000000000000000000000000000..8ceff41cdefefe18f0ceeab34325be2c01877ec1 --- /dev/null +++ b/test_imgs/building2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:321cea509578013fe503a64fc0ce718794afbaf2148f8aa0b22890d99995e336 +size 1060076 diff --git a/test_imgs/cute_robot.png b/test_imgs/cute_robot.png new file mode 100644 index 0000000000000000000000000000000000000000..74efaa346a7ecf1a0b6a03771901fe986c40532a Binary files /dev/null and b/test_imgs/cute_robot.png differ diff --git a/test_imgs/cyber.png b/test_imgs/cyber.png new file mode 100644 index 0000000000000000000000000000000000000000..b3eb747c055733ecd94eb6a3ff058f39ce738109 Binary files /dev/null and b/test_imgs/cyber.png differ diff --git a/test_imgs/dog.png b/test_imgs/dog.png new file mode 100644 index 0000000000000000000000000000000000000000..022c2d13a1b7d79a979ccd6a51b80c1181d06478 Binary files /dev/null and b/test_imgs/dog.png differ diff --git a/test_imgs/dog2.png b/test_imgs/dog2.png new file mode 100644 index 0000000000000000000000000000000000000000..481e3763be2cdd28f2de4558752d659d40bf3232 Binary files /dev/null and b/test_imgs/dog2.png differ diff --git a/test_imgs/house.png b/test_imgs/house.png new file mode 100644 index 0000000000000000000000000000000000000000..f46f6700781e157884b4325dac3835b102eb7212 Binary files /dev/null and b/test_imgs/house.png differ diff --git a/test_imgs/house_line.png b/test_imgs/house_line.png new file mode 100644 index 0000000000000000000000000000000000000000..d5e2460de33d9f5d212c13428fd7cdaa06e459bc Binary files /dev/null and b/test_imgs/house_line.png differ diff --git a/test_imgs/human.png b/test_imgs/human.png new file mode 100644 index 0000000000000000000000000000000000000000..646628c758479f5401618c4e49ded080b3db00b1 Binary files /dev/null and b/test_imgs/human.png differ diff --git a/test_imgs/human_line.png b/test_imgs/human_line.png new file mode 100644 index 0000000000000000000000000000000000000000..cfa67539697ddb5c9a1f1ba33f5f6b7d6ca119ed Binary files /dev/null and b/test_imgs/human_line.png differ diff --git a/test_imgs/man.png b/test_imgs/man.png new file mode 100644 index 0000000000000000000000000000000000000000..680889ed588457c6fe50dca3a4ed0f8451756d61 Binary files /dev/null and b/test_imgs/man.png differ diff --git a/test_imgs/old.png b/test_imgs/old.png new file mode 100644 index 0000000000000000000000000000000000000000..16fd6699817936f8d53b9eb2e6d672f1c60e30b7 Binary files /dev/null and b/test_imgs/old.png differ diff --git a/test_imgs/pose1.png b/test_imgs/pose1.png new file mode 100644 index 0000000000000000000000000000000000000000..e566939b171a8babae471125c753446b2368bf9c Binary files /dev/null and b/test_imgs/pose1.png differ diff --git a/test_imgs/pose2.png b/test_imgs/pose2.png new file mode 100644 index 0000000000000000000000000000000000000000..ba4ba5822b6ddcd2a8b642be42ec29179343fc20 Binary files /dev/null and b/test_imgs/pose2.png differ diff --git a/test_imgs/room.png b/test_imgs/room.png new file mode 100644 index 0000000000000000000000000000000000000000..318dcf26ef005d78a21eab452e80fc9b4f10c134 Binary files /dev/null and b/test_imgs/room.png differ diff --git a/test_imgs/sd.png b/test_imgs/sd.png new file mode 100644 index 0000000000000000000000000000000000000000..f32273ca8c63349c30d244e02ff31348aaaaa723 Binary files /dev/null and b/test_imgs/sd.png differ diff --git a/test_imgs/shose.png b/test_imgs/shose.png new file mode 100644 index 0000000000000000000000000000000000000000..cb003809fbd04825330c246d927783e48d0840d5 Binary files /dev/null and b/test_imgs/shose.png differ diff --git a/test_imgs/toy.png b/test_imgs/toy.png new file mode 100644 index 0000000000000000000000000000000000000000..77c70c5b99fec55c2b184345055a39428bdc6298 Binary files /dev/null and b/test_imgs/toy.png differ diff --git a/test_imgs/user_1.png b/test_imgs/user_1.png new file mode 100644 index 0000000000000000000000000000000000000000..c117fe1c4d8bceacd786819554998445957e15e1 Binary files /dev/null and b/test_imgs/user_1.png differ diff --git a/test_imgs/user_3.png b/test_imgs/user_3.png new file mode 100644 index 0000000000000000000000000000000000000000..bd1cf3029f5061e2d6c62599d68c0279bc2c2970 Binary files /dev/null and b/test_imgs/user_3.png differ diff --git a/tool_add_control.py b/tool_add_control.py new file mode 100644 index 0000000000000000000000000000000000000000..8076b5143405e5516b063f4fd63096f65cffbed2 --- /dev/null +++ b/tool_add_control.py @@ -0,0 +1,50 @@ +import sys +import os + +assert len(sys.argv) == 3, 'Args are wrong.' + +input_path = sys.argv[1] +output_path = sys.argv[2] + +assert os.path.exists(input_path), 'Input model does not exist.' +assert not os.path.exists(output_path), 'Output filename already exists.' +assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.' + +import torch +from share import * +from cldm.model import create_model + + +def get_node_name(name, parent_name): + if len(name) <= len(parent_name): + return False, '' + p = name[:len(parent_name)] + if p != parent_name: + return False, '' + return True, name[len(parent_name):] + + +model = create_model(config_path='./models/cldm_v15.yaml') + +pretrained_weights = torch.load(input_path) +if 'state_dict' in pretrained_weights: + pretrained_weights = pretrained_weights['state_dict'] + +scratch_dict = model.state_dict() + +target_dict = {} +for k in scratch_dict.keys(): + is_control, name = get_node_name(k, 'control_') + if is_control: + copy_k = 'model.diffusion_' + name + else: + copy_k = k + if copy_k in pretrained_weights: + target_dict[k] = pretrained_weights[copy_k].clone() + else: + target_dict[k] = scratch_dict[k].clone() + print(f'These weights are newly added: {k}') + +model.load_state_dict(target_dict, strict=True) +torch.save(model.state_dict(), output_path) +print('Done.') diff --git a/tool_add_control_sd21.py b/tool_add_control_sd21.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3ac5f9bb398b9d20cb7d750f1f7c0717670eae --- /dev/null +++ b/tool_add_control_sd21.py @@ -0,0 +1,50 @@ +import sys +import os + +assert len(sys.argv) == 3, 'Args are wrong.' + +input_path = sys.argv[1] +output_path = sys.argv[2] + +assert os.path.exists(input_path), 'Input model does not exist.' +assert not os.path.exists(output_path), 'Output filename already exists.' +assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.' + +import torch +from share import * +from cldm.model import create_model + + +def get_node_name(name, parent_name): + if len(name) <= len(parent_name): + return False, '' + p = name[:len(parent_name)] + if p != parent_name: + return False, '' + return True, name[len(parent_name):] + + +model = create_model(config_path='./models/cldm_v21.yaml') + +pretrained_weights = torch.load(input_path) +if 'state_dict' in pretrained_weights: + pretrained_weights = pretrained_weights['state_dict'] + +scratch_dict = model.state_dict() + +target_dict = {} +for k in scratch_dict.keys(): + is_control, name = get_node_name(k, 'control_') + if is_control: + copy_k = 'model.diffusion_' + name + else: + copy_k = k + if copy_k in pretrained_weights: + target_dict[k] = pretrained_weights[copy_k].clone() + else: + target_dict[k] = scratch_dict[k].clone() + print(f'These weights are newly added: {k}') + +model.load_state_dict(target_dict, strict=True) +torch.save(model.state_dict(), output_path) +print('Done.') diff --git a/tool_transfer_control.py b/tool_transfer_control.py new file mode 100644 index 0000000000000000000000000000000000000000..b84442cc93f7f9c30cb7311b8675d9124a6e8ec9 --- /dev/null +++ b/tool_transfer_control.py @@ -0,0 +1,59 @@ +path_sd15 = './models/v1-5-pruned.ckpt' +path_sd15_with_control = './models/control_sd15_openpose.pth' +path_input = './models/anything-v3-full.safetensors' +path_output = './models/control_any3_openpose.pth' + + +import os + + +assert os.path.exists(path_sd15), 'Input path_sd15 does not exists!' +assert os.path.exists(path_sd15_with_control), 'Input path_sd15_with_control does not exists!' +assert os.path.exists(path_input), 'Input path_input does not exists!' +assert os.path.exists(os.path.dirname(path_output)), 'Output folder not exists!' + + +import torch +from share import * +from cldm.model import load_state_dict + + +sd15_state_dict = load_state_dict(path_sd15) +sd15_with_control_state_dict = load_state_dict(path_sd15_with_control) +input_state_dict = load_state_dict(path_input) + + +def get_node_name(name, parent_name): + if len(name) <= len(parent_name): + return False, '' + p = name[:len(parent_name)] + if p != parent_name: + return False, '' + return True, name[len(parent_name):] + + +keys = sd15_with_control_state_dict.keys() + +final_state_dict = {} +for key in keys: + is_first_stage, _ = get_node_name(key, 'first_stage_model') + is_cond_stage, _ = get_node_name(key, 'cond_stage_model') + if is_first_stage or is_cond_stage: + final_state_dict[key] = input_state_dict[key] + continue + p = sd15_with_control_state_dict[key] + is_control, node_name = get_node_name(key, 'control_') + if is_control: + sd15_key_name = 'model.diffusion_' + node_name + else: + sd15_key_name = key + if sd15_key_name in input_state_dict: + p_new = p + input_state_dict[sd15_key_name] - sd15_state_dict[sd15_key_name] + # print(f'Offset clone from [{sd15_key_name}] to [{key}]') + else: + p_new = p + # print(f'Direct clone to [{key}]') + final_state_dict[key] = p_new + +torch.save(final_state_dict, path_output) +print('Transferred model saved at ' + path_output) diff --git a/tutorial_dataset.py b/tutorial_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fb327f981d10cf94e6a7f55f5b2b4497d3e7a9cb --- /dev/null +++ b/tutorial_dataset.py @@ -0,0 +1,39 @@ +import json +import cv2 +import numpy as np + +from torch.utils.data import Dataset + + +class MyDataset(Dataset): + def __init__(self): + self.data = [] + with open('./training/fill50k/prompt.json', 'rt') as f: + for line in f: + self.data.append(json.loads(line)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + + source_filename = item['source'] + target_filename = item['target'] + prompt = item['prompt'] + + source = cv2.imread('./training/fill50k/' + source_filename) + target = cv2.imread('./training/fill50k/' + target_filename) + + # Do not forget that OpenCV read images in BGR order. + source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB) + target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB) + + # Normalize source images to [0, 1]. + source = source.astype(np.float32) / 255.0 + + # Normalize target images to [-1, 1]. + target = (target.astype(np.float32) / 127.5) - 1.0 + + return dict(jpg=target, txt=prompt, hint=source) + diff --git a/tutorial_dataset_test.py b/tutorial_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c4c355065d15c266a8b4e8c68dfcbe2b246730 --- /dev/null +++ b/tutorial_dataset_test.py @@ -0,0 +1,12 @@ +from tutorial_dataset import MyDataset + +dataset = MyDataset() +print(len(dataset)) + +item = dataset[1234] +jpg = item['jpg'] +txt = item['txt'] +hint = item['hint'] +print(txt) +print(jpg.shape) +print(hint.shape) diff --git a/tutorial_train.py b/tutorial_train.py new file mode 100644 index 0000000000000000000000000000000000000000..393d7addb164c32eff9c3d675e4f32fb555868f0 --- /dev/null +++ b/tutorial_train.py @@ -0,0 +1,35 @@ +from share import * + +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from tutorial_dataset import MyDataset +from cldm.logger import ImageLogger +from cldm.model import create_model, load_state_dict + + +# Configs +resume_path = './models/control_sd15_ini.ckpt' +batch_size = 4 +logger_freq = 300 +learning_rate = 1e-5 +sd_locked = True +only_mid_control = False + + +# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs. +model = create_model('./models/cldm_v15.yaml').cpu() +model.load_state_dict(load_state_dict(resume_path, location='cpu')) +model.learning_rate = learning_rate +model.sd_locked = sd_locked +model.only_mid_control = only_mid_control + + +# Misc +dataset = MyDataset() +dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True) +logger = ImageLogger(batch_frequency=logger_freq) +trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger]) + + +# Train! +trainer.fit(model, dataloader) diff --git a/tutorial_train_sd21.py b/tutorial_train_sd21.py new file mode 100644 index 0000000000000000000000000000000000000000..8bbc148f9b1e90561f5a186cc0be94c911dd67cf --- /dev/null +++ b/tutorial_train_sd21.py @@ -0,0 +1,35 @@ +from share import * + +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from tutorial_dataset import MyDataset +from cldm.logger import ImageLogger +from cldm.model import create_model, load_state_dict + + +# Configs +resume_path = './models/control_sd21_ini.ckpt' +batch_size = 4 +logger_freq = 300 +learning_rate = 1e-5 +sd_locked = True +only_mid_control = False + + +# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs. +model = create_model('./models/cldm_v21.yaml').cpu() +model.load_state_dict(load_state_dict(resume_path, location='cpu')) +model.learning_rate = learning_rate +model.sd_locked = sd_locked +model.only_mid_control = only_mid_control + + +# Misc +dataset = MyDataset() +dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True) +logger = ImageLogger(batch_frequency=logger_freq) +trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger]) + + +# Train! +trainer.fit(model, dataloader)